Esempio n. 1
0
def main():
    """Segment the test set with a trained UNet and build a submission.

    The configs (``configs/core.config``, ``configs/unet.config``) and the
    checkpoint ``args.model`` are read from the parent directory of
    ``args.dir``. The test set is passed to ``segment`` together with
    ``args.dir`` (created if missing), and ``make_submission`` is then called
    with ``args.dir`` and ``args.csv``.
    """
    global args
    args = parser.parse_args()
    args.batch_size = 1  # only segment one image for experiment

    # Configs and checkpoint live one directory above the output directory.
    model_dir = os.path.dirname(args.dir)
    core_config_path = os.path.join(model_dir, 'configs/core.config')
    unet_config_path = os.path.join(model_dir, 'configs/unet.config')

    core_config = CoreConfig()
    core_config.read(core_config_path)
    print('Using core configuration from {}'.format(core_config_path))

    # loading Unet configuration
    unet_config = UnetConfig()
    unet_config.read(unet_config_path, args.train_image_size)
    print('Using unet configuration from {}'.format(unet_config_path))

    offset_list = core_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = core_config.num_classes
    num_colors = core_config.num_colors
    num_offsets = len(core_config.offsets)
    # model configurations from unet config
    start_filters = unet_config.start_filters
    up_mode = unet_config.up_mode
    merge_mode = unet_config.merge_mode
    depth = unet_config.depth

    model = UNet(num_classes, num_offsets,
                 in_channels=num_colors, depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode,
                 merge_mode=merge_mode)

    model_path = os.path.join(model_dir, args.model)
    if os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # The identity map_location deserializes every tensor on the CPU, so
        # a GPU-trained checkpoint also loads on a CPU-only machine.
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("loaded.")
    else:
        print("=> no checkpoint found at '{}'".format(model_path))

    # Switch to inference mode before segmenting; otherwise any dropout or
    # batch-norm layers in the UNet would behave as they do during training.
    model.eval()

    testset = WaldoTestset(args.test_data, args.train_image_size,
                           job=args.job, num_jobs=args.num_jobs)
    print('Total samples in the test set: {0}'.format(len(testset)))

    dataloader = torch.utils.data.DataLoader(
        testset, num_workers=1, batch_size=args.batch_size)

    segment_dir = args.dir
    if not os.path.exists(segment_dir):
        os.makedirs(segment_dir)
    segment(dataloader, segment_dir, model, core_config)
    make_submission(segment_dir, args.csv)
Esempio n. 2
0
    def __init__(self, **wdtargs):
        """Create the UNet backbone, the flow head and the occlusion module.

        :param wdtargs: extra keyword options, handed as a single dict to
            ``self.init_wdt``. Documented keys/defaults:
            {shape_in = (256,256), use_l2 = True, channel_in = 3, stride = 1,
            kernel_size = 2, use_cst = True}
        :type wdtargs: dict
        """
        super(FlowEstimator, self).__init__()
        # Backbone network; meaning of the two positional arguments depends
        # on the project's UNet signature (TODO confirm).
        self.unet = UNet(8, 8)
        # 3x3 conv, stride 1, padding 1 (spatial size preserved): 8 input
        # channels -> 2 output channels (presumably the x/y flow components).
        self.predictflow = nn.Conv2d(in_channels=8, out_channels=2,
                                     kernel_size=3, stride=1, padding=1)
        self.occlusion = Occlusion()

        # Remaining optional settings are applied by init_wdt.
        self.init_wdt(wdtargs)
Esempio n. 3
0
def main():
    """Train a UNet configured by the core and unet config files.

    ``args.core_config`` / ``args.unet_config`` select the config files; an
    empty string keeps the default configuration object. Training uses
    ``<train_dir>/train`` and validation ``<train_dir>/val`` for epochs
    ``args.start_epoch`` .. ``args.epochs - 1``, with ``save_checkpoint``
    called after every epoch. Finally ``sample`` is run on the validation
    loader with ``<args.dir>/imgs`` as output directory.
    """
    global args, best_loss
    args = parser.parse_args()

    if args.tensorboard:
        from tensorboard_logger import configure
        print("Using tensorboard")
        configure("%s" % (args.dir))

    # loading core configuration: '' means "keep the defaults", so only a
    # non-empty path has to exist on disk.
    c_config = CoreConfig()
    if args.core_config == '':
        print('No core config file given, using default core configuration')
    elif not os.path.exists(args.core_config):
        sys.exit('Cannot find the config file: {}'.format(args.core_config))
    else:
        c_config.read(args.core_config)
        print('Using core configuration from {}'.format(args.core_config))

    # loading Unet configuration (same '' convention as above)
    u_config = UnetConfig()
    if args.unet_config == '':
        print('No unet config file given, using default unet configuration')
    elif not os.path.exists(args.unet_config):
        sys.exit('Cannot find the unet configuration file: {}'.format(
            args.unet_config))
    else:
        # need train_image_size for validation
        u_config.read(args.unet_config, args.train_image_size)
        print('Using unet configuration from {}'.format(args.unet_config))

    offset_list = c_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = c_config.num_classes
    num_colors = c_config.num_colors
    num_offsets = len(c_config.offsets)
    # model configurations from unet config
    start_filters = u_config.start_filters
    up_mode = u_config.up_mode
    merge_mode = u_config.merge_mode
    depth = u_config.depth

    train_data = args.train_dir + '/train'
    val_data = args.train_dir + '/val'

    trainset = WaldoDataset(train_data, c_config, args.train_image_size)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              num_workers=4,
                                              batch_size=args.batch_size,
                                              shuffle=True)

    valset = WaldoDataset(val_data, c_config, args.train_image_size)
    valloader = torch.utils.data.DataLoader(valset,
                                            num_workers=4,
                                            batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print('Total samples: {0} \n'
          'Using {1} samples for training, '
          '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    # create model
    model = UNet(num_classes,
                 num_offsets,
                 in_channels=num_colors,
                 depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode,
                 merge_mode=merge_mode).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints written below carry the best loss under both
            # 'best_loss' and the legacy 'best_prec1' key; older checkpoints
            # only have 'best_prec1'.
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            elif 'best_prec1' in checkpoint:
                best_loss = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                # stored epoch is the next one to run when resuming
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'best_prec1': best_loss,  # legacy key, kept for old readers
            }, is_best)
    print('Best validation loss: ', best_loss)

    # visualize some example outputs
    outdir = '{}/imgs'.format(args.dir)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    sample(model, valloader, outdir, c_config)
Esempio n. 4
0
    # data info
    train_set = dataset_lits2.Lits_DataSet(args.crop_size,
                                           args.batch_size,
                                           args.resize_scale,
                                           args.dataset_path,
                                           mode='train')
    val_set = dataset_lits2.Lits_DataSet(args.crop_size,
                                         args.batch_size,
                                         args.resize_scale,
                                         args.dataset_path,
                                         mode='val')
    # NOTE(review): no batch_size is given, so DataLoader uses its default
    # of 1; Lits_DataSet receives args.batch_size itself and presumably
    # batches internally (TODO confirm).
    train_loader = DataLoader(dataset=train_set, shuffle=True)
    val_loader = DataLoader(dataset=val_set, shuffle=True)
    # model info
    model = UNet(1, [32, 48, 64, 96, 128],
                 3,
                 net_mode='3d',
                 conv_block=RecombinationBlock).to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum)
    init_util.print_network(model)
    # model = nn.DataParallel(model, device_ids=[0])   # multi-GPU

    # Keep the Logger instance under its own name. Rebinding `logger` would
    # shadow the name that provides `Logger`; inside a function it would
    # also make `logger` local, so `logger.Logger` raises UnboundLocalError.
    log = logger.Logger('./output/{}'.format(args.save))
    for epoch in range(1, args.epochs + 1):
        common.adjust_learning_rate(optimizer, epoch, args)
        train(model, train_loader, optimizer, epoch, log)
        val(model, val_loader, epoch, log)
        # Pickles the whole model object (architecture and parameters).
        torch.save(model, './output/{}/state.pkl'.format(args.save))
        # torch.save(model.state_dict(), PATH)  # saves parameters only
Esempio n. 5
0
import torch
from torchvision import transforms
from tqdm import tqdm
import utils
from torch.utils.data import DataLoader
from models.Unet import UNet
import os
import matplotlib.pyplot as plt

# hyperparams

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_channels=3, n_classes=4, bilinear=True)
print(model)
# `device` falls back to the CPU when CUDA is unavailable; map_location lets a
# checkpoint saved from a GPU run be restored there as well.
model.load_state_dict(torch.load('weights/Unet_e3.pth', map_location=device))
model.to(device)
# NOTE(review): the model is still in training mode at this point; call
# model.eval() before inference if UNet contains dropout / batch-norm layers.

transform = transforms.Compose([
    utils.transforms.RandomMirror(),
    utils.transforms.ToTensor(),
    utils.transforms.Downsize(2)
])

dataset = utils.datasets.SteelDefectDataset(
    csv_file='train.csv',
    root_dir='data/severstal-steel-defect-detection',
    transform=transform)
test_loader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch, data in tqdm(enumerate(test_loader),
                        total=len(test_loader),
def main():
    """Train a UNet that predicts class masks and offset boundaries.

    If any of the cached files ``args.train_data`` / ``args.val_data`` /
    ``args.test_data`` is missing, all three are regenerated with
    ``DataProcess`` (ratio argument 0.9) and saved with ``torch.save``.
    Training runs for epochs ``args.start_epoch`` .. ``args.epochs - 1`` with
    ``save_checkpoint`` after every epoch. Afterwards one training batch, its
    targets and the model's predictions are written as images to ``imgs/``.
    """
    global args, best_loss
    args = parser.parse_args()

    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # All three caches must exist; otherwise rebuild them from the raw data.
    if not (os.path.exists(args.train_data) and os.path.exists(args.val_data)
            and os.path.exists(args.test_data)):
        train, val, test = DataProcess(args.train_path, args.test_path, 0.9,
                                       args.img_channels)
        torch.save(train, args.train_data)
        torch.save(val, args.val_data)
        torch.save(test, args.test_data)

    s_trans = tsf.Compose([
        tsf.ToPILImage(),
        tsf.Resize((args.img_height, args.img_width)),
        tsf.ToTensor(),
    ])

    offset_list = [(1, 1), (0, -2)]

    # training / validation sets (already split by DataProcess)
    trainset = Dataset(args.train_data, s_trans, offset_list, args.num_classes,
                       args.img_height, args.img_width)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              num_workers=1,
                                              batch_size=args.batch_size)

    valset = Dataset(args.val_data, s_trans, offset_list, args.num_classes,
                     args.img_height, args.img_width)
    valloader = torch.utils.data.DataLoader(valset,
                                            num_workers=1,
                                            batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print(
        'Total samples: {0} \n'
        'Using {1} samples for training, '
        '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    # create model
    model = UNet(args.num_classes,
                 len(offset_list),
                 in_channels=3,
                 depth=args.depth).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Checkpoints written below carry the best loss under both
            # 'best_loss' and the legacy 'best_prec1' key; older checkpoints
            # only have 'best_prec1'.
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            elif 'best_prec1' in checkpoint:
                best_loss = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                # stored epoch is the next one to run when resuming
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'best_prec1': best_loss,  # legacy key, kept for old readers
            }, is_best)
    # Single-argument print behaves the same under Python 2 and 3.
    print('Best validation loss: {}'.format(best_loss))

    # Visualize some predicted masks on training data to get a better intuition
    # about the performance. Comment it if not necessary.
    if not os.path.exists('imgs'):
        os.makedirs('imgs')
    datailer = iter(trainloader)
    img, classification, bound = next(datailer)
    torchvision.utils.save_image(img, 'imgs/raw.png')
    for i in range(len(offset_list)):
        torchvision.utils.save_image(bound[:, i:i + 1, :, :],
                                     'imgs/bound_{}.png'.format(i))
    for i in range(args.num_classes):
        torchvision.utils.save_image(classification[:, i:i + 1, :, :],
                                     'imgs/class_{}.png'.format(i))
    model.eval()  # inference behaviour (dropout off, BN running stats)
    img = torch.autograd.Variable(img).cuda()
    predictions = model(img)
    predictions = predictions.data
    # Output channels: the class channels first, then one per offset.
    class_pred = predictions[:, :args.num_classes, :, :]
    bound_pred = predictions[:, args.num_classes:, :, :]
    for i in range(len(offset_list)):
        torchvision.utils.save_image(bound_pred[:, i:i + 1, :, :],
                                     'imgs/bound_pred{}.png'.format(i))
    for i in range(args.num_classes):
        torchvision.utils.save_image(class_pred[:, i:i + 1, :, :],
                                     'imgs/class_pred{}.png'.format(i))
Esempio n. 7
0
    # data info
    # Training/validation datasets built by dataset_lits_faster.Lits_DataSet.
    train_set = dataset_lits_faster.Lits_DataSet(args.crop_size,
                                                 args.batch_size,
                                                 args.resize_scale,
                                                 args.dataset_path,
                                                 mode='train')
    val_set = dataset_lits_faster.Lits_DataSet(args.crop_size,
                                               args.batch_size,
                                               args.resize_scale,
                                               args.dataset_path,
                                               mode='val')
    # NOTE(review): no batch_size is given, so DataLoader uses its default
    # of 1; Lits_DataSet receives args.batch_size itself and presumably
    # batches internally (TODO confirm).
    train_loader = DataLoader(dataset=train_set, shuffle=True)
    val_loader = DataLoader(dataset=val_set, shuffle=True)
    # model info
    # 3D UNet with RecombinationBlock conv blocks; meaning of the positional
    # arguments (1, channel list, 3) depends on the project's UNet (TODO confirm).
    model = UNet(1, [32, 48, 64, 96, 128],
                 3,
                 net_mode='3d',
                 conv_block=RecombinationBlock).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    init_util.print_network(model)
    # model = nn.DataParallel(model, device_ids=[0])   # multi-GPU

    log = logger.Logger('./output/{}'.format(args.save))

    best = [0, np.inf]  # [epoch, performance] of the best model so far
    trigger = 0  # early-stop counter
    # NOTE(review): `best` and `trigger` are not updated within the lines
    # shown here; the loop body appears truncated.
    for epoch in range(1, args.epochs + 1):
        common.adjust_learning_rate(optimizer, epoch, args)
        train_log = train(model, train_loader)
        val_log = val(model, val_loader)
        log.update(epoch, train_log, val_log)
Esempio n. 8
0
def main():
    """Segment a single test image with a trained UNet and visualize it.

    Configs (``configs/core.config``, ``configs/unet.config``) and the
    checkpoint ``args.model`` are read from ``args.dir``. ``sample`` runs the
    model with ``<args.dir>/segment`` as output directory. The first image's
    class and adjacency predictions are then turned into object masks by
    ``ObjectSegmenter`` and shown via ``visualize_mask``.
    """
    global args
    args = parser.parse_args()
    # only segment one image for experiment
    args.batch_size = 1

    core_config_path = os.path.join(args.dir, 'configs/core.config')
    unet_config_path = os.path.join(args.dir, 'configs/unet.config')

    core_config = CoreConfig()
    core_config.read(core_config_path)
    print('Using core configuration from {}'.format(core_config_path))

    unet_config = UnetConfig()
    unet_config.read(unet_config_path, args.train_image_size)
    print('Using unet configuration from {}'.format(unet_config_path))

    offsets = core_config.offsets
    print("offsets are: {}".format(offsets))
    n_classes = core_config.num_classes

    # Architecture: sizes from the core config, layout from the unet config.
    model = UNet(n_classes,
                 len(offsets),
                 in_channels=core_config.num_colors,
                 depth=unet_config.depth,
                 start_filts=unet_config.start_filters,
                 up_mode=unet_config.up_mode,
                 merge_mode=unet_config.merge_mode)

    ckpt_path = os.path.join(args.dir, args.model)
    if not os.path.isfile(ckpt_path):
        print("=> no checkpoint found at '{}'".format(ckpt_path))
    else:
        print("=> loading checkpoint '{}'".format(ckpt_path))
        # identity map_location: deserialize all tensors on the CPU
        state = torch.load(ckpt_path,
                           map_location=lambda stor, _loc: stor)
        model.load_state_dict(state['state_dict'])
        print("loaded.")

    model.eval()  # evaluation mode for inference

    testset = WaldoDataset(args.test_data, core_config, args.train_image_size)
    print('Total samples in the test set: {0}'.format(len(testset)))
    loader = torch.utils.data.DataLoader(testset, num_workers=1,
                                         batch_size=args.batch_size)

    out_dir = '{}/segment'.format(args.dir)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    img, class_pred, adj_pred = sample(model, loader, out_dir, core_config)

    segmenter = ObjectSegmenter(class_pred[0].detach().numpy(),
                                adj_pred[0].detach().numpy(),
                                n_classes, offsets)
    mask_pred, object_class = segmenter.run_segmentation()
    visualize_mask({
        # (color, height, width) -> (height, width, color)
        'img': np.moveaxis(img[0].numpy(), 0, -1),
        'mask': mask_pred.astype(int),
        'object_class': object_class,
    }, core_config)
Esempio n. 9
0
                             args.dataset_path,
                             mode='train')
    val_set = Lits_DataSet(args.crop_size,
                           args.resize_scale,
                           args.dataset_path,
                           mode='val')
    train_loader = DataLoader(dataset=train_set,
                              batch_size=args.batch_size,
                              num_workers=1,
                              shuffle=True)
    val_loader = DataLoader(dataset=val_set,
                            batch_size=args.batch_size,
                            num_workers=1,
                            shuffle=True)
    # model info
    model = UNet(1, [16, 32, 48, 64, 96],
                 3,
                 net_mode='3d',
                 conv_block=RecombinationBlock).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    init_util.print_network(model)
    # model = nn.DataParallel(model, device_ids=[0,1])  # multi-GPU

    # Keep the Logger instance under its own name. Rebinding `logger` would
    # shadow the name that provides `Logger`; inside a function it would
    # also make `logger` local, so `logger.Logger` raises UnboundLocalError.
    log = logger.Logger('./output/{}'.format(args.save))
    for epoch in range(1, args.epochs + 1):
        common.adjust_learning_rate(optimizer, epoch, args)
        train(model, train_loader, optimizer, epoch, log)
        val(model, val_loader, epoch, log)
        torch.save(model, './output/{}/model.pth'.format(
            args.save))  # Save model with parameters
        # torch.save(model.state_dict(), './output/{}/model.pth'.format(args.save))  # Only save parameters
Esempio n. 10
0
def main():
    """Train a single-output UNet, visualize one batch and evaluate on test.

    If any of the cached files ``args.train_data`` / ``args.val_data`` /
    ``args.test_data`` is missing, all three are regenerated with
    ``DataProcess`` (ratio argument 0.9) and saved. After training for epochs
    ``args.start_epoch`` .. ``args.epochs - 1`` (checkpointing every epoch),
    one training batch, its mask and the thresholded prediction are written
    as images. Then ``exp/<name>/model_best.pth.tar`` is loaded and passed to
    ``Test``.
    """
    global args, best_loss
    args = parser.parse_args()

    if args.tensorboard:
        print("Using tensorboard")
        configure("exp/%s" % (args.name))

    # All three caches must exist; otherwise rebuild them from the raw data.
    if not (os.path.exists(args.train_data) and os.path.exists(args.val_data)
            and os.path.exists(args.test_data)):
        train, val, test = DataProcess(args.train_path, args.test_path, 0.9,
                                       args.img_channels)
        t.save(train, args.train_data)
        t.save(val, args.val_data)
        t.save(test, args.test_data)

    s_trans = tsf.Compose([
        tsf.ToPILImage(),
        tsf.Resize((args.img_height, args.img_width)),
        tsf.ToTensor(),
    ])

    # Targets are resized with nearest-neighbour so label values are not
    # blended by interpolation.
    t_trans = tsf.Compose([
        tsf.ToPILImage(),
        tsf.Resize((args.img_height, args.img_width),
                   interpolation=PIL.Image.NEAREST),
        tsf.ToTensor(),
    ])

    # training / validation sets (already split by DataProcess)
    trainset = TrainDataset(args.train_data, s_trans, t_trans)
    trainloader = t.utils.data.DataLoader(trainset,
                                          num_workers=1,
                                          batch_size=args.batch_size,
                                          shuffle=True)

    valset = TrainDataset(args.val_data, s_trans, t_trans)
    valloader = t.utils.data.DataLoader(valset,
                                        num_workers=1,
                                        batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print(
        'Total samples: {0} \n'
        'Using {1} samples for training, '
        '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    testset = TestDataset(args.test_data, s_trans)
    testloader = t.utils.data.DataLoader(testset, num_workers=1, batch_size=1)

    # create model
    model = UNet(1, in_channels=3, depth=args.depth).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum(p.data.nelement() for p in model.parameters())))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = t.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # Restore into best_loss, the name the training loop compares
            # against. Checkpoints written below carry both 'best_loss' and
            # the legacy 'best_prec1' key; older ones only 'best_prec1'.
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            elif 'best_prec1' in checkpoint:
                best_loss = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    optimizer = t.optim.Adam(model.parameters(), lr=1e-3)

    # Train
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint(
            {
                # stored epoch is the next one to run when resuming
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'best_prec1': best_loss,  # legacy key, kept for old readers
            }, is_best)
    # Single-argument print behaves the same under Python 2 and 3.
    print('Best validation loss: {}'.format(best_loss))

    # Visualize some predicted masks on training data to get a better intuition
    # about the performance. Comment it if not necessary.
    datailer = iter(trainloader)
    img, mask = next(datailer)
    torchvision.utils.save_image(img, 'raw.png')
    torchvision.utils.save_image(mask, 'mask.png')
    model.eval()  # inference behaviour (dropout off, BN running stats)
    img = t.autograd.Variable(img).cuda()
    img_pred = model(img)
    img_pred = img_pred.data
    torchvision.utils.save_image(img_pred > 0.5, 'predicted.png')

    # Load the best model and evaluate on test set
    checkpoint = t.load('exp/%s/' % (args.name) + 'model_best.pth.tar')
    model.load_state_dict(checkpoint['state_dict'])
    Test(testloader, model)
Esempio n. 11
0
import torch
from torchvision import transforms
from tqdm import tqdm
import utils
from torch.utils.data import DataLoader
from models.Unet import UNet
import os

# hyperparams
epoch = 4  # number of training epochs; the loop below runs 1..epoch

# Prefer the GPU when one is present.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = UNet(n_channels=3, n_classes=4, bilinear=True)
print(model)
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Augment (random mirror), convert to tensors, then downsize by a factor of 2.
transform = transforms.Compose([
    utils.transforms.RandomMirror(),
    utils.transforms.ToTensor(),
    utils.transforms.Downsize(2),
])

dataset = utils.datasets.SteelDefectDataset(
    csv_file='train.csv',
    root_dir='data/severstal-steel-defect-detection',
    transform=transform,
)
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)

criterion = utils.loss.SegmentMSELoss()

for e in range(1,epoch+1):
Esempio n. 12
0
    return result_test, result_val, loss_val, loss_train


def init_normal(m):
    """Xavier-uniform initialise the weight of every ``nn.Linear`` module.

    Meant to be used as ``model.apply(init_normal)``. Other module types and
    the Linear bias are left untouched.

    :param m: a module visited by ``nn.Module.apply``.
    """
    # try other initialisations
    # Exact type match (not isinstance), as before: subclasses of nn.Linear
    # keep their default initialisation.
    if type(m) is nn.Linear:
        # In-place ``xavier_uniform_``; the un-suffixed ``xavier_uniform`` is a
        # deprecated alias that emits a warning on every call.
        nn.init.xavier_uniform_(m.weight)


if __name__ == '__main__':

    # Command-line options (depth, kernel sizes, data dir, batch size, ...).
    opts = get_args()

    # getting the structure of your model
    # UNet built from the command-line hyper-parameters with 3 output classes.
    model = UNet(depth=opts.depth,
                 kernel_size=opts.kernel_size,
                 kernel_num=opts.kernel_num,
                 n_classes=3)

    # Move to the GPU only when the model itself reports CUDA use.
    if model.useCUDA:
        model.cuda()
    # summary(model, input_size=(1, 512, 256))
    # raise
    # Apply init_normal to every sub-module (re-initialises nn.Linear weights).
    model.apply(init_normal)

    # Training data: shuffled mini-batches of size opts.bs from opts.inDir.
    # NOTE(review): despite the name, `trainset` holds a DataLoader.
    # `__getsize__` is a project-specific method, not a Python protocol;
    # `len(traindata)` is the idiomatic spelling of `__len__()`.
    traindata = Healthy(mode='train', root=opts.inDir, opts=opts)
    trainset = DataLoader(traindata, shuffle=True, batch_size=opts.bs)
    print('train', traindata.__len__(), traindata.__getsize__())

    # Validation data, loaded the same way (shuffling is not needed here).
    valdata = Healthy(mode='val', root=opts.inDir, opts=opts)
    valset = DataLoader(valdata, shuffle=True, batch_size=opts.bs)
    print('val', valdata.__len__(), valdata.__getsize__())
Esempio n. 13
0
    val_dice2 = metrics.dice(pred, target, 2)

    pred_img = torch.argmax(pred,dim=1)
    img = sitk.GetImageFromArray(np.squeeze(np.array(pred_img.numpy(),dtype='uint8'),axis=0))
    sitk.WriteImage(img, os.path.join(save_path, filename))

    # save_tool.save(filename)
    print('\nAverage loss: {:.4f}\tdice0: {:.4f}\tdice1: {:.4f}\tdice2: {:.4f}\t\n'.format(
        val_loss, val_dice0, val_dice1, val_dice2))
    return val_loss, val_dice0, val_dice1, val_dice2

if __name__ == '__main__':
    args = config.args
    device = torch.device('cpu' if args.cpu else 'cuda')
    # model info
    model = UNet(1, [32, 48, 64, 96, 128], 3, net_mode='3d',
                 conv_block=RecombinationBlock).to(device)
    # map_location=device: without it, a checkpoint saved from a GPU run is
    # restored onto CUDA, which fails when args.cpu is set on a GPU-less host.
    ckpt = torch.load('./output/{}/best_model.pth'.format(args.save),
                      map_location=device)
    model.load_state_dict(ckpt['net'])

    # data info
    # NOTE(review): machine-specific absolute path; consider making it an arg.
    test_data_path = r'F:\datasets\LiTS\test'
    result_save_path = r'./output/{}/result'.format(args.save)
    if not os.path.exists(result_save_path):
        os.mkdir(result_save_path)
    # Patch sizes and strides for tiling each test volume
    # (presumably s = slices, h = height, w = width).
    cut_param = {'patch_s': 32,
                 'patch_h': 128,
                 'patch_w': 128,
                 'stride_s': 24,
                 'stride_h': 96,
                 'stride_w': 96}
    datasets = test_Datasets(test_data_path, cut_param,
                             resize_scale=args.resize_scale)
    for dataset,file_idx in datasets: