示例#1
0
#-*- coding:utf-8 -*-

from data_extension import DataExtension
from unet import UNet
import cv2
if __name__ == '__main__':
    # Pipeline: augment the data, build and load the training set, then run
    # the UNet on a single test image and store the predicted label map.

    de = DataExtension()
    # Augment (extend) the data
    de.data_extension()
    # Create the training set
    de.create_train_data()
    # Load the training set
    train_img, train_lbl = de.load_train_data()

    unet = UNet()
    # Training is currently disabled; uncomment the next line to retrain.
    #unet.unet_train(train_img, train_lbl)

    test_path = './data/test/test.png'
    img = cv2.imread(test_path)
    # cv2.imread does not raise on a missing/unreadable file, it returns None;
    # fail here with a clear message instead of deep inside the prediction.
    if img is None:
        raise FileNotFoundError(
            'Could not read test image: {}'.format(test_path))

    label = unet.unet_predict_img(img)
    cv2.imwrite('test_label.png', label)
示例#2
0
Date created: 11/3/2020
Date last modified: 11/24/2020
Python Version: 3
"""
import numpy as np
import matplotlib.pyplot as plt

from data import test_generator, test_data
from setting import *
from unet import UNet


"""
Do evaluation with result from train.py
"""
# Rebuild the network for inputs of shape (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
# and restore the weights stored in "unet.hdf5".
model = UNet((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
model.load_weights("unet.hdf5")
# NOTE(review): `/` produces a float step count; presumably one full pass over
# test_data is intended -- confirm the framework accepts a non-integer `steps`.
results = model.evaluate(test_generator,
                         verbose=1,
                         steps=len(test_data) / BATCH_SIZE,
                         return_dict=True)
# return_dict=True makes the metrics addressable by name.
print('Evaluation dice coefficient is', results['dice_coef'])


"""
plot image for samples
"""
batches = 2
plt.figure(figsize=(10, 10))
for i in range(batches):
    img, mask = next(test_generator)
示例#3
0
class TextRemoval(object):
    """Implementation of Noise2Noise from Lehtinen et al. (2018).

    Wraps a UNet together with its optimizer, learning-rate scheduler and
    loss, and provides checkpointing plus the training/validation loops.
    """

    def __init__(self, params, trainable):
        """Initializes model.

        Args:
            params: Object whose attributes carry the configuration (e.g.
                ``learning_rate``, ``adam``, ``nb_epochs``, ``cuda``).
            trainable: When true, optimizer, scheduler and loss are built too.
        """

        self.p = params
        self.trainable = trainable
        self._compile()


    def _compile(self):
        """Compiles model (architecture, loss function, optimizers, etc.)."""

        # The network is built for 3 input channels
        self.model = UNet(in_channels=3)

        # Set optimizer and loss, if in training mode
        if self.trainable:
            # self.p.adam is read as (beta1, beta2, eps)
            self.optim = Adam(self.model.parameters(),
                              lr=self.p.learning_rate,
                              betas=self.p.adam[:2],
                              eps=self.p.adam[2])

            # Learning rate adjustment: halve the LR when the monitored
            # (validation) loss stops improving
            self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optim,
                patience=self.p.nb_epochs/4, factor=0.5, verbose=True)

            # Loss function
            self.loss = nn.L1Loss()

        # CUDA support: only used when both available and requested
        self.use_cuda = torch.cuda.is_available() and self.p.cuda
        if self.use_cuda:
            self.model = self.model.cuda()
            if self.trainable:
                self.loss = self.loss.cuda()


    def _print_params(self):
        """Formats parameters to print when training."""

        print('Training parameters: ')
        # Report the effective CUDA setting rather than the requested one
        self.p.cuda = self.use_cuda
        param_dict = vars(self.p)
        pretty = lambda x: x.replace('_', ' ').capitalize()
        print('\n'.join('  {} = {}'.format(pretty(k), str(v)) for k, v in param_dict.items()))
        print()


    def save_model(self, epoch, stats, first=False):
        """Saves model to files; can be overwritten at every epoch to save disk space.

        Args:
            epoch: Zero-based epoch index; used in the checkpoint file name and
                to look up ``stats['valid_loss'][epoch]``.
            stats: JSON-serializable dictionary of tracked statistics.
            first: When true, the checkpoint directory is chosen and created.
        """

        # Create directory for model checkpoints, if nonexistent
        if first:
            ckpt_dir_name = f'{datetime.now()}'
            if self.p.ckpt_overwrite:
                # ``noise_type`` is never assigned by this class, so reading
                # ``self.noise_type`` unconditionally raised AttributeError.
                # Use it when a caller has set it (on the instance or on the
                # params); otherwise keep the timestamp name chosen above.
                ckpt_dir_name = (getattr(self, 'noise_type', None)
                                 or getattr(self.p, 'noise_type', None)
                                 or ckpt_dir_name)

            self.ckpt_dir = os.path.join(self.p.ckpt_save_path, ckpt_dir_name)
            # makedirs also creates missing parents of ckpt_save_path and is a
            # no-op when the directory already exists
            os.makedirs(self.ckpt_dir, exist_ok=True)

        # Save checkpoint dictionary
        if self.p.ckpt_overwrite:
            fname_unet = '{}/overwrite.pt'.format(self.ckpt_dir)
        else:
            valid_loss = stats['valid_loss'][epoch]
            fname_unet = '{}/epoch{}-{:>1.5f}.pt'.format(self.ckpt_dir, epoch + 1, valid_loss)
        print('Saving checkpoint to: {}\n'.format(fname_unet))
        torch.save(self.model.state_dict(), fname_unet)

        # Save stats to JSON
        fname_dict = '{}/stats.json'.format(self.ckpt_dir)
        with open(fname_dict, 'w') as fp:
            json.dump(stats, fp, indent=2)


    def load_model(self, ckpt_fname):
        """Loads model from checkpoint file.

        Args:
            ckpt_fname: Path of a state dict written by ``torch.save``.
        """

        print('Loading checkpoint from: {}'.format(ckpt_fname))
        # Without CUDA, remap tensors saved on a GPU onto the CPU
        map_location = None if self.use_cuda else 'cpu'
        self.model.load_state_dict(torch.load(ckpt_fname, map_location=map_location))


    def _on_epoch_end(self, stats, train_loss, epoch, epoch_start, valid_loader):
        """Tracks and saves stats after each epoch.

        Runs validation, steps the LR scheduler on the validation loss,
        appends the epoch's numbers to ``stats`` and writes a checkpoint.
        """

        # Evaluate model on validation set
        print('\rTesting model on validation set... ', end='')
        epoch_time = time_elapsed_since(epoch_start)[0]
        valid_loss, valid_time, valid_psnr = self.eval(valid_loader)
        show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr)

        # Decrease learning rate if plateau
        self.scheduler.step(valid_loss)

        # Save checkpoint (the directory is created on the first epoch)
        stats['train_loss'].append(train_loss)
        stats['valid_loss'].append(valid_loss)
        stats['valid_psnr'].append(valid_psnr)
        self.save_model(epoch, stats, epoch == 0)

        # Plot stats
        if self.p.plot_stats:
            loss_str = f'{self.p.loss.upper()} loss'
            plot_per_epoch(self.ckpt_dir, 'Valid loss', stats['valid_loss'], loss_str)
            plot_per_epoch(self.ckpt_dir, 'Valid PSNR', stats['valid_psnr'], 'PSNR (dB)')



    def eval(self, valid_loader):
        """Evaluates denoiser on validation set.

        Leaves the model in evaluation mode.

        Returns:
            Tuple ``(valid_loss, valid_time, psnr_avg)``: the averaged loss,
            the first element of ``time_elapsed_since`` and the averaged PSNR.
        """

        self.model.train(False)

        valid_start = datetime.now()
        loss_meter = AvgMeter()
        psnr_meter = AvgMeter()

        # No gradients are needed for validation; skip building the graph
        with torch.no_grad():
            for source, target in valid_loader:
                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise
                source_denoised = self.model(source)

                # Update loss
                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Compute PSNR per sample. Iterate over the actual batch
                # length so a shorter final batch does not raise IndexError,
                # and move the batch to the CPU once rather than per sample.
                denoised_cpu = source_denoised.cpu()
                target_cpu = target.cpu()
                for i in range(denoised_cpu.shape[0]):
                    psnr_meter.update(psnr(denoised_cpu[i], target_cpu[i]).item())

        valid_loss = loss_meter.avg
        valid_time = time_elapsed_since(valid_start)[0]
        psnr_avg = psnr_meter.avg

        return valid_loss, valid_time, psnr_avg


    def train(self, train_loader, valid_loader):
        """Trains denoiser on training set.

        Args:
            train_loader: Iterable of ``(source, target)`` training batches;
                its length must be a multiple of ``self.p.report_interval``.
            valid_loader: Iterable of ``(source, target)`` validation batches,
                evaluated at the end of every epoch.
        """

        self.model.train(True)

        self._print_params()
        num_batches = len(train_loader)
        assert num_batches % self.p.report_interval == 0, 'Report interval must divide total number of batches'

        # Dictionaries of tracked stats
        stats = {'noise_param': self.p.noise_param,
                 'train_loss': [],
                 'valid_loss': [],
                 'valid_psnr': []}

        # Main training loop
        train_start = datetime.now()
        for epoch in range(self.p.nb_epochs):
            print('EPOCH {:d} / {:d}'.format(epoch + 1, self.p.nb_epochs))

            # eval() at the end of the previous epoch switched the model to
            # evaluation mode; switch back before training again.
            self.model.train(True)

            # Some stats trackers
            epoch_start = datetime.now()
            train_loss_meter = AvgMeter()
            loss_meter = AvgMeter()
            time_meter = AvgMeter()

            # Minibatch SGD
            for batch_idx, (source, target) in enumerate(train_loader):
                batch_start = datetime.now()
                progress_bar(batch_idx, num_batches, self.p.report_interval, loss_meter.val)

                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise image
                source_denoised = self.model(source)

                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Zero gradients, perform a backward pass, and update the weights
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                # Report/update statistics
                time_meter.update(time_elapsed_since(batch_start)[1])
                if (batch_idx + 1) % self.p.report_interval == 0 and batch_idx:
                    show_on_report(batch_idx, num_batches, loss_meter.avg, time_meter.avg)
                    train_loss_meter.update(loss_meter.avg)
                    loss_meter.reset()
                    time_meter.reset()

            # Epoch end, save and reset tracker
            self._on_epoch_end(stats, train_loss_meter.avg, epoch, epoch_start, valid_loader)
            train_loss_meter.reset()

        train_elapsed = time_elapsed_since(train_start)[0]
        print('Training done! Total elapsed time: {}\n'.format(train_elapsed))
示例#4
0

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=args.n_classes, bilinear=True)
    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
示例#5
0
def train():
    """Trains the UNet configured in ``config`` and logs to TensorBoard.

    Resumes from ``config.weight_with_optimizer`` when that file exists.
    During training, the weights are written to ``config.best_model`` (and,
    together with the optimizer state, to
    ``config.best_model_with_optimizer``) whenever a single training batch
    reaches a new lowest loss; note that this snapshot is taken before the
    optimizer step for that batch. Every 10th epoch a regular checkpoint is
    written to ``config.weight`` / ``config.weight_with_optimizer``.

    The reported train/valid losses are sums over the batches of the epoch.
    """
    transforms = [
        Transforms.RondomFlip(),
        Transforms.RandomRotate(15),
        Transforms.Log(0.5),
        Transforms.Blur(0.2),
        Transforms.ToGray(),
        Transforms.ToTensor()
    ]
    train_dataset = UNetDataset('./data/train/',
                                './data/train_cleaned/',
                                transform=transforms)
    train_dataLoader = DataLoader(dataset=train_dataset,
                                  batch_size=config.BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=0)

    # NOTE(review): validation reuses the same transform list as training
    # and is shuffled -- confirm this is intended.
    valid_dataset = UNetDataset('./data/valid/',
                                './data/valid_cleaned/',
                                transform=transforms)
    valid_dataLoader = DataLoader(dataset=valid_dataset,
                                  batch_size=config.BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=0)

    net = UNet(n_channels=config.n_channels,
               n_classes=config.n_classes).to(config.device)
    writer = SummaryWriter()
    optimizer = optim.Adam(net.parameters(), lr=config.LR)
    if config.n_classes > 1:
        loss_func = nn.CrossEntropyLoss().to(config.device)
    else:
        loss_func = nn.BCEWithLogitsLoss().to(config.device)
    best_loss = float('inf')

    # Resume from the last regular checkpoint, if any
    if os.path.exists(config.weight_with_optimizer):
        checkpoint = torch.load(config.weight_with_optimizer,
                                map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print('load weight')

    for epoch in range(config.EPOCH):
        train_loss = 0
        net.train()
        for step, (batch_x, batch_y) in enumerate(train_dataLoader):
            batch_x = batch_x.to(device=config.device)
            batch_y = batch_y.squeeze(1).to(device=config.device)
            output = net(batch_x)
            loss = loss_func(output, batch_y)
            # Work with the plain float: keeping the loss tensor in
            # ``best_loss`` would keep its autograd graph alive.
            loss_value = loss.item()
            train_loss += loss_value
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(
                    {
                        'net': net.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, config.best_model_with_optimizer)
                torch.save({'net': net.state_dict()}, config.best_model)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        net.eval()
        eval_loss = 0
        # Validation needs no gradients; avoid building autograd graphs
        with torch.no_grad():
            for step, (batch_x, batch_y) in enumerate(valid_dataLoader):
                batch_x = batch_x.to(device=config.device)
                batch_y = batch_y.squeeze(1).to(device=config.device)
                output = net(batch_x)
                valid_loss = loss_func(output, batch_y)
                eval_loss += valid_loss.item()

        writer.add_scalar("train_loss", train_loss, epoch)
        writer.add_scalar("eval_loss", eval_loss, epoch)
        print("*" * 80)
        print('epoch: %d | train loss: %.4f | valid loss: %.4f' %
              (epoch, train_loss, eval_loss))
        print("*" * 80)

        # Regular checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save(
                {
                    'net': net.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, config.weight_with_optimizer)
            torch.save({'net': net.state_dict()}, config.weight)
            print('saved')

    writer.close()
示例#6
0
def train_net(load_pth,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
    """Trains a UNet on the VOC/SBD loaders and checkpoints after each epoch.

    Args:
        load_pth: Path of a state dict to start from; falsy to start fresh.
        device: ``torch.device`` the network and the batches are moved to.
        epochs: Number of passes over the training loader.
        batch_size: Only used for the validation cadence, the log message and
            the TensorBoard run name; the DataLoaders use a fixed size of 2.
        lr: Only used for the log message and the TensorBoard run name; the
            optimizer below uses a fixed learning rate of 1.0e-10.
        val_percent: Unused by the current (VOC) data loading.
        save_cp: When true, a checkpoint is written after every epoch.
        img_scale: Only used for the log message and the TensorBoard run name.
    """

    # Data loading for VOC (an earlier Carvana/BasicDataset variant used
    # `val_percent` and `img_scale`; it is no longer active).
    t_loader = data_loader(voc_data_path,
                           sbd_path,
                           is_transform=True,
                           split='train_aug',
                           img_size=(500, 500))
    v_loader = data_loader(voc_data_path,
                           sbd_path,
                           is_transform=True,
                           split='val',
                           img_size=(500, 500))

    trainloader = DataLoader(t_loader,
                             batch_size=2,
                             num_workers=16,
                             shuffle=True)
    valloader = DataLoader(v_loader, batch_size=2, num_workers=16)

    # init model
    net = UNet(n_channels=3, n_classes=t_loader.n_classes, bilinear=True)
    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
    if load_pth:
        state = torch.load(load_pth, map_location=device)
        # Checkpoints written at the end of this function come from the
        # DataParallel wrapper, so their keys carry a 'module.' prefix that
        # the bare network does not know. Accept both layouts.
        prefix = 'module.'
        state = {(key[len(prefix):] if key.startswith(prefix) else key): value
                 for key, value in state.items()}
        net.load_state_dict(state)
        logging.info(f'Model loaded from {load_pth}')
    net = net.to(device=device)
    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    n_train = len(t_loader)
    n_val = len(v_loader)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    # NOTE(review): the learning rate is hard-coded and the `lr` argument is
    # not used here -- confirm this is intended before wiring it through.
    optimizer = optim.SGD(net.parameters(),
                          lr=1.0e-10,
                          weight_decay=0.0005,
                          momentum=0.99)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     'min',
                                                     patience=2)

    loss_fc = get_loss_function('cross_entropy', size_average=False)

    # Validate/log roughly ten times per epoch. Clamp to 1 so that small
    # datasets (n_train < 10 * batch_size) do not cause a modulo by zero.
    report_every = max(1, n_train // (10 * batch_size))

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in trainloader:
                imgs = batch['image'].to(device=device, dtype=torch.float32)
                true_masks = batch['mask'].to(device=device, dtype=torch.long)

                masks_pred = net(imgs)
                loss = loss_fc(input=masks_pred, target=true_masks)
                loss_value = loss.item()
                epoch_loss += loss_value
                writer.add_scalar('Loss/train', loss_value, global_step)

                pbar.set_postfix(**{'loss (batch)': loss_value})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % report_every == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        # Parameters that took no part in the backward pass
                        # have no gradient to plot.
                        if value.grad is not None:
                            writer.add_histogram(
                                'grads/' + tag,
                                value.grad.data.cpu().numpy(),
                                global_step)
                    val_score = eval_net(net, valloader, loss_fc, device)
                    # eval_net may leave the network in evaluation mode;
                    # make sure the rest of the epoch trains in train mode.
                    net.train()
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    logging.info(
                        'Validation cross entropy: {}'.format(val_score))
                    writer.add_scalar('Loss/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)

        if save_cp:
            if not os.path.isdir(dir_checkpoint):
                os.makedirs(dir_checkpoint, exist_ok=True)
                logging.info('Created checkpoint directory')
            # os.path.join yields the same path as plain concatenation when
            # dir_checkpoint ends with a separator, and keeps the file inside
            # the directory when it does not.
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint,
                                    f'CP_epoch{epoch + 1}.pth'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
        out_files = args.output

    return out_files


def mask_to_image(mask):
    """Turns the first entry of ``mask`` into an 8-bit PIL image.

    The shape of ``mask[0]`` is printed, then the entry is scaled by 255,
    cast to ``uint8`` and handed to ``Image.fromarray``.
    """
    print(mask[0].shape)
    pixels = mask[0] * 255
    as_bytes = pixels.astype(np.uint8)
    return Image.fromarray(as_bytes)


if __name__ == "__main__":
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=14)

    logging.info("Loading model {}".format(args.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(args.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
import numpy as np
import matplotlib.pyplot as plt
import random
from loader import SnakeGameDataset as dataset
from unet import UNet
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

code_size = 120

# Pick the device first so the checkpoint can be mapped onto it: without
# map_location, weights saved on a GPU fail to load on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

autoencoder = UNet(1, in_channels=1, depth=5, start_filts=8, merge_mode='add')
autoencoder.load_state_dict(
    torch.load("/home/joshua/Desktop/ConvAutoencoder/unet_model.pth",
               map_location=device))
# Inference only: switch to evaluation mode before moving to the device
autoencoder.eval()
autoencoder.to(device)

test_data = dataset(
    root_dir="/home/joshua/Desktop/ConvAutoencoder/data/testData",
    transform=None,
    prefix="testDataBoards1.npy")

vis_1, vis_2, vis_3 = random.choice(test_data), random.choice(
    test_data), random.choice(test_data)
check_1 = torch.from_numpy(vis_1[0]).unsqueeze(0).unsqueeze(1).float()
check_2 = torch.from_numpy(vis_2[0]).unsqueeze(0).unsqueeze(1).float()
check_3 = torch.from_numpy(vis_3[0]).unsqueeze(0).unsqueeze(1).float()
res_1, res_2, res_3 = autoencoder(Variable(check_1.to(device))), autoencoder(
示例#9
0
def main():
    """Parses the command line and trains the segmentation UNet.

    Builds the FarmDataset training loader, the UNet(3, 4) model, the chosen
    optimizer ('sgd' or 'adam') and the LR scheduler, then calls ``train``
    once per epoch. The whole model object is saved to ``./tmp/model<epoch>``
    every 10th epoch up to epoch 50 and after every epoch beyond that.

    Raises:
        NotImplementedError: If ``--optim`` is neither 'sgd' nor 'adam'.
    """
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Training settings
    parser = argparse.ArgumentParser(
        description='Scratch segmentation Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        metavar='N',
                        help='input batch size for training (default: 8)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=8,
                        metavar='N',
                        help='input batch size for testing (default: 8)')
    parser.add_argument('--epochs',
                        type=int,
                        default=200,
                        metavar='N',
                        help='number of epochs to train (default: 200)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=20,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--optim', type=str, default='sgd', help="optimizer")
    parser.add_argument('--lr-scheduler',
                        type=str,
                        default='cos',
                        choices=['poly', 'step', 'cos'],
                        help='lr scheduler mode: (default: cos)')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    print('my device is :', device)

    kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(FarmDataset(istrain=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    # Start from scratch (startepoch == 0) or resume from a saved model.
    # The model follows `device`, so --no-cuda and CPU-only hosts work
    # instead of failing on an unconditional .cuda().
    startepoch = 0
    if startepoch:
        model = torch.load('./tmp/model{}'.format(startepoch),
                           map_location=device)
    else:
        model = UNet(3, 4).to(device)

    # Define Optimizer
    train_params = model.parameters()
    weight_decay = 5e-4
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(train_params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=weight_decay,
                                    nesterov=False)
    elif args.optim == 'adam':
        # Adam has no `nesterov` argument; passing it raised TypeError.
        optimizer = torch.optim.Adam(train_params,
                                     lr=args.lr,
                                     weight_decay=weight_decay)
    else:
        raise NotImplementedError("Optimizer have't been implemented")

    # The scheduler's fourth argument is the number of iterations per epoch;
    # derive it from the loader so it stays right when the batch size or the
    # dataset changes.
    scheduler = LR_Scheduler(args.lr_scheduler,
                             args.lr,
                             args.epochs,
                             len(train_loader),
                             lr_step=10,
                             warmup_epochs=10)

    # torch.save does not create missing directories
    os.makedirs('./tmp', exist_ok=True)

    for epoch in range(startepoch, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, scheduler)
        # Every 10th epoch up to 50, then every epoch
        if epoch > 50 or epoch % 10 == 0:
            torch.save(model, './tmp/model{}'.format(epoch))
import cv2
import numpy as np
import torch
import os
from unet import UNet

weight = './weight/weight.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load net
print('load net')
net = UNet(1, 2).to(device)
if os.path.exists(weight):
    # map_location lets weights saved on a GPU load on a CPU-only machine
    checkpoint = torch.load(weight, map_location=device)
    net.load_state_dict(checkpoint['net'])
else:
    # Nothing to run without trained weights; say so instead of exiting
    # silently (the exit status is kept at 0 as before).
    print('weight file not found: {}'.format(weight))
    exit(0)

# load img
print('load img')
dir = './data/test/'
# Collect every .jpg/.png file directly inside the test directory
filenames = [
    os.path.join(dir, filename) for filename in os.listdir(dir)
    if filename.endswith(('.jpg', '.png'))
]
totalN = len(filenames)

for index, filename in enumerate(filenames):
    img = cv2.imread(filename, 0)
    if img is None:
        print('img is None')
示例#11
0
def pro(path):
    """Segments the image at ``path`` and classifies it as benign/malignant.

    The image is resized to 224x224, run through the 3-class UNet loaded
    from 'trainmodels.pth', and the per-pixel argmax is turned into a colour
    mask (class 2 -> red channel, class 1 -> green channel). The result is
    '恶性' (malignant) when more than 30% of the marked pixels belong to
    class 2, otherwise '良性' (benign). The annotated image returned by
    ``contours`` and the resized input are saved to a ``res`` sub-directory
    next to the input file.

    Args:
        path: Path of the input image.

    Returns:
        Tuple ``(fp1, fp2, results)``: path of the annotated image, path of
        the saved resized input ('tmp.png') and the classification string.
    """
    n_class = 3
    model = UNet(n_channels=3, n_classes=n_class)
    model = nn.DataParallel(model, device_ids=[0])
    model.load_state_dict(torch.load('trainmodels.pth'))
    # Inference: layers such as batch norm / dropout must not run in
    # training mode.
    model.eval()

    tensor2pil = transforms.ToPILImage('RGB')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.3301, 0.3301, 0.3301],
                             std=[0.1938, 0.1938, 0.1938])
    ])
    image = Image.open(path).convert('RGB')
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10
    image = image.resize((224, 224), Image.LANCZOS)
    in_img_tensor = transform(image)
    n = in_img_tensor.size()

    # Add the batch dimension and move the input to GPU 0
    data = in_img_tensor.unsqueeze(0).cuda(0)
    with torch.no_grad():
        output = model(data)  # run the image through the model
    out = torch.reshape(output.cpu(), (n_class, n[1], n[2]))
    pre = torch.max(out, 0)[1]  # per-pixel class index

    out_img_tensor = torch.zeros(3, n[1], n[2])
    out_img_tensor[0, pre == 2] = 1
    out_img_tensor[1, pre == 1] = 1

    # Share of class-2 pixels among all marked pixels. With no marked pixel
    # at all the ratio is taken as 0, i.e. the result is benign (this is
    # also what the former NaN comparison evaluated to).
    marked = torch.sum(out_img_tensor)
    ratio = torch.sum(out_img_tensor[0]) / marked if marked > 0 else 0.0
    if ratio > 0.3:
        results = '恶性'
    else:
        results = '良性'  # benign vs. malignant decision

    out_img_ndarray = cv.cvtColor(np.asarray(tensor2pil(out_img_tensor)),
                                  cv.COLOR_RGB2BGR)
    image_cv = cv.cvtColor(np.asarray(image), cv.COLOR_RGB2BGR)
    out_img_ndarray = contours(out_img_ndarray, image_cv, results)

    # [0] is the directory of the file, [1] is the file name
    path_filename = osp.split(path)
    f_name = os.path.splitext(path_filename[1])[0]
    out_img_pil = Image.fromarray(
        cv.cvtColor(out_img_ndarray, cv.COLOR_BGR2RGB))
    # A bare file name has an empty directory part; fall back to '.' so the
    # output does not end up in '/res/' at the filesystem root.
    fp = (path_filename[0] or '.') + '/res/'
    os.makedirs(fp, exist_ok=True)
    fp1 = fp + f_name + '.png'
    fp2 = fp + 'tmp.png'
    out_img_pil.save(fp1)
    image.save(fp2)
    return fp1, fp2, results
示例#12
0
def main():
    """Runs the trained 2-class UNet over ``test_loader`` and saves overlays.

    For every batch the RGB image and the depth map are concatenated into a
    4-channel input. Wherever the class-1 score exceeds the class-0 score
    inside the first 320x200 output positions, a filled circle is drawn on
    the matching RGB image loaded from '../SIM_dataset_v10/rgb/'. The
    annotated image is written to 'report/<file name>'.

    Raises:
        FileNotFoundError: If the RGB image for a batch cannot be read.
    """
    import os

    # Extent of the output grid that is scanned and used for scaling
    grid_w, grid_h = 320, 200

    model = nn.DataParallel(UNet(n_channels=4, n_classes=2, bilinear=True))
    model.cuda()
    model.load_state_dict(torch.load('best_model.pt'))
    model.eval()

    # cv2.imwrite fails silently (returns False) if the directory is missing
    os.makedirs('report', exist_ok=True)

    # Pure inference: do not record autograd graphs
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            img = data[0].cuda()
            depth = data[3].cuda().float()
            img_name = data[4][0]

            input_data = torch.cat([img, depth], dim=1)

            # Drop the batch dimension (the name lookup above already
            # assumes one sample per batch)
            cls = model(input_data).squeeze(dim=0)

            nameload = '../SIM_dataset_v10/rgb/' + test_dir[batch_idx]
            real_img = cv2.imread(nameload)
            # cv2.imread returns None instead of raising on failure
            if real_img is None:
                raise FileNotFoundError(nameload)
            real_img = cv2.cvtColor(real_img, cv2.COLOR_BGR2RGB)
            img_h, img_w = real_img.shape[:2]

            # Positions where class 1 wins, found in one vectorised step
            # instead of a 320x200 Python loop; nonzero() yields them in the
            # same (w outer, h inner) order the loop used.
            wins = cls[1, :grid_w, :grid_h] > cls[0, :grid_w, :grid_h]
            for w, h in wins.nonzero().cpu().tolist():
                print('yes')
                cv2.circle(real_img,
                           (int(w * img_w / grid_w), int(h * img_h / grid_h)),
                           13, [250, 0, 0], -1)

            # Keep only the last path component of the image name
            cv2.imwrite('report/' + img_name.split('/')[-1], real_img)
示例#13
0
class Trainer(object):
    """Training/validation driver for a UNet optimised with SGD.

    The constructor wires together the data loaders, the network, the SGD
    optimizer, the ``MSELoss`` criterion, the learning-rate scheduler and the
    TensorBoard summary.  ``training`` runs one optimisation epoch and
    ``validation`` evaluates on the validation loader, then writes a model
    checkpoint and a JSON file holding the accumulated loss.
    """

    def __init__(self, config, args):
        """Build every component needed for training and optionally resume.

        Args:
            config: settings object.  This class reads ``lr``, ``momentum``,
                ``weight_decay``, ``lr_scheduler``, ``epochs``, ``lr_step``,
                ``warmup_epochs``, ``lr_ratio``, ``freeze_bn``, ``dataset``
                and ``batch_size`` from it.
            args: run-time options.  This class reads ``cuda``, ``resume``,
                ``start_epoch`` and ``save_folder`` from it.

        Raises:
            RuntimeError: if ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args
        self.config = config
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(
            config)

        # Define network
        self.model = UNet(n_channels=1, n_classes=3, bilinear=True)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=config.lr,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)

        # Define Criterion
        self.criterion = MSELoss(cuda=args.cuda)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # ``map_location`` is an argument of ``torch.load``;
            # ``load_state_dict`` does not accept it.  On CPU-only runs the
            # tensors are remapped to the CPU while the file is read.
            map_location = None if args.cuda else torch.device('cpu')
            checkpoint = torch.load(args.resume, map_location=map_location)
            self._bare_model().load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

    def _bare_model(self):
        """Return the underlying network, unwrapping ``DataParallel`` if used.

        The model is wrapped only when ``args.cuda`` is set, so code that
        needs the raw module (state dict, ``freeze_bn``) must not assume a
        ``.module`` attribute exists.
        """
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def training(self, epoch):
        """Run one training epoch over ``self.train_loader``.

        Logs the learning rate and per-iteration loss to TensorBoard, and
        after the loop visualises the last batch and prints the summed loss.

        Args:
            epoch: zero-based epoch index, used for the global iteration
                counter and passed to the scheduler.
        """
        train_loss = 0.0
        self.model.train()
        # Read the flag from the instance's own config rather than relying on
        # a module-level ``config`` name being defined.
        if self.config.freeze_bn:
            self._bare_model().freeze_bn()
        tbar = tqdm(self.train_loader)
        for i, sample in enumerate(tbar):
            itr = epoch * len(self.train_loader) + i
            self.summary.writer.add_scalar(
                'Train/lr', self.optimizer.param_groups[0]['lr'], itr)
            A_image, A_target = sample['image'], sample['label']

            if self.args.cuda:
                A_image, A_target = A_image.cuda(), A_target.cuda()

            self.scheduler(self.optimizer, i, epoch, self.best_pred_source, 0.,
                           self.config.lr_ratio)

            A_output = self.model(A_image)

            self.optimizer.zero_grad()

            # Supervised loss
            main_loss = self.criterion(A_output, A_target)

            main_loss.backward()

            self.optimizer.step()

            train_loss += main_loss.item()
            self.summary.writer.add_scalar('Train/MSELoss', main_loss.item(),
                                           itr)
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))

        # Show the results of the last iteration
        print("Add Train images at epoch" + str(epoch))
        self.summary.visualize_image('Train', self.config.dataset, A_image,
                                     A_target, A_output, epoch, 5)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.config.batch_size + A_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

    def validation(self, epoch):
        """Evaluate on ``self.val_loader`` and save a checkpoint for ``epoch``.

        Writes the model weights to
        ``<save_folder>models/epoch<epoch>.pth`` and the summed validation
        loss to ``<save_folder>/eval/epoch<epoch>.json``.

        Args:
            epoch: epoch index used for logging and for the output file names.
        """
        def get_metrics(tbar):
            """Return the loss summed over every batch yielded by ``tbar``."""
            test_loss = 0.0
            for i, sample in enumerate(tbar):
                image, target = sample['image'], sample['label']

                if self.args.cuda:
                    image, target = image.cuda(), target.cuda()

                with torch.no_grad():
                    output = self.model(image)

                loss = self.criterion(output, target)
                test_loss += loss.item()
                tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            print("Add Validation-Source images at epoch" + str(epoch))
            self.summary.visualize_image('Val', self.config.dataset, image,
                                         target, output, epoch, 5)
            print('Loss: %.3f' % test_loss)

            return test_loss

        self.model.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        err = get_metrics(tbar)

        # NOTE(review): ``err`` is a summed loss, yet the larger value is kept
        # as the "best" one; confirm whether a minimum was intended.
        if err > self.best_pred_source:
            self.best_pred_source = err
        print('Saving state, epoch:', epoch)
        # Unwrap DataParallel only when it is actually present so that
        # CPU-only runs can save checkpoints too.
        torch.save(
            self._bare_model().state_dict(),
            self.args.save_folder + 'models/' + 'epoch' + str(epoch) + '.pth')
        loss_file = {'err': err}
        with open(
                os.path.join(self.args.save_folder, 'eval',
                             'epoch' + str(epoch) + '.json'), 'w') as f:
            json.dump(loss_file, f)
示例#14
0
# Phases for which databases were created.
phases = ["train", "val"]
# Phases in which validation is performed. Validation is time consuming, so
# rather than running it for both phases it is done only for the validation
# phase at the end of the epoch.
validation_phases = ["val"]

# Select the Torch device on which tensors are allocated and manipulated:
# the requested GPU when CUDA is available, otherwise the local CPU.
# The CUDA-only queries are guarded so that the CPU fallback is actually
# reachable on machines without a GPU.
if torch.cuda.is_available():
    print(torch.cuda.get_device_properties(gpuid))
    torch.cuda.set_device(gpuid)
    torchDevice = torch.device('cuda:' + str(gpuid))
else:
    torchDevice = torch.device('cpu')

# Build the UNetRuntime according to the parameters specified above and move
# it to the selected device.
UNetRuntime = UNet(n_classes=n_classes, in_channels=in_channels, padding=padding, depth=depth, wf=wf, up_mode=up_mode, batch_norm=batch_norm).to(torchDevice)
# Print the total number of parameter elements in the model (every tensor
# returned by parameters() is counted).
params = sum([np.prod(p.size()) for p in UNetRuntime.parameters()])
print("total params: \t"+str(params))


################################################################################
# helper function for pretty printing of current time and remaining time
################################################################################

########################################
def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``.

    Args:
        s: duration in seconds (int or float).

    Returns:
        A string such as ``'2m 5s'``; both fields are rendered with ``%d``,
        so fractional seconds are truncated.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)
示例#15
0
 def make_generator(self, generator_args: Dict):
     """Build the generator network.

     Args:
         generator_args: keyword arguments forwarded unchanged to ``UNet``.

     Returns:
         The newly constructed ``UNet`` instance.
     """
     generator = UNet(**generator_args)
     return generator
示例#16
0
def seg_wsi(args):
    """Segment whole-slide images with a UNet and save patches and previews.

    For every ``*.svs`` / ``*.tiff`` file in ``args.wsi_dir`` that does not
    already have a result directory under ``args.res_dir``, the slide is read
    batch by batch, the softmax output of channel 1 is stitched into a
    slide-sized map (overlaps are merged with an element-wise maximum),
    patches are sampled from the result and several preview images are
    written to a per-slide directory.

    Args:
        args: options object.  This function reads ``imSize``,
            ``input_channel``, ``load_from_checkpoint``, ``wsi_dir``,
            ``res_dir``, ``batch_size``, ``slide_level``, ``start``, ``end``,
            ``num_samples`` and ``threshold`` from it.

    A failure while processing one slide is reported (with traceback) and the
    loop continues with the next slide.  A failure to load the weights
    terminates the process through ``sys.exit``.
    """
    import traceback

    sampled_img_size = 256
    # the border is single-side, e.g. [512,512] -> [528, 528]
    border = 16

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    K.set_session(sess)

    img = tf.placeholder(tf.float32,
                         shape=(None, args.imSize, args.imSize,
                                args.input_channel))
    # NOTE(review): ``opt`` is not defined inside this function; presumably it
    # is a module-level options object equivalent to ``args`` - confirm, and
    # confirm that the hard-coded 3 channels agree with args.input_channel.
    model = UNet().create_model(img_shape=[opt.imSize, opt.imSize, 3],
                                num_class=opt.num_class,
                                rate=0.0,
                                input_tensor=preprocess_input(img))
    unet_pred = tf.nn.softmax(model.output)

    sess.run(tf.global_variables_initializer())  # initialize

    try:
        model.load_weights(args.load_from_checkpoint)
        print("[*] Success to load model.")
    except Exception as e:
        # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
        # propagate untouched; the underlying error is printed because not
        # every failure here is a missing file.
        print(e)
        sys.exit("[*] Failed to find a checkpoint " +
                 args.load_from_checkpoint)

    print("=> Starting WSI patch generation...")

    wsi_filelist = []
    wsi_filelist.extend(sorted(glob.glob(os.path.join(args.wsi_dir, '*.svs'))))
    wsi_filelist.extend(sorted(glob.glob(os.path.join(args.wsi_dir,
                                                      '*.tiff'))))

    # Names of the sub-directories of res_dir, i.e. slides already handled.
    segmented_files = next(os.walk(args.res_dir))[1]

    SlideHandler = SlideLoader(args.batch_size,
                               level=args.slide_level,
                               to_real_scale=4,
                               imsize=args.imSize)
    print("=> Found {} whole slide images in total.".format(len(wsi_filelist)))
    print("=> {} has been processed.".format(len(segmented_files)))
    wsi_filelist = [
        a for a in wsi_filelist
        if os.path.splitext(os.path.basename(a))[0] not in segmented_files
    ]
    print("=> {} is being processed.".format(len(wsi_filelist)))
    end = min(args.end, len(wsi_filelist))
    for index in range(args.start, end):  # TODO remove this s
        wsi_filepath = wsi_filelist[index]
        wsi_img_name = os.path.splitext(os.path.basename(wsi_filepath))[0]

        if os.path.isdir(os.path.join(args.res_dir, wsi_img_name)) or \
            wsi_img_name in segmented_files:
            continue

        start_time = datetime.now()
        print("=> Start {}/{} segment {}".format(index + 1, end, wsi_img_name))

        crop_down_scale = 1

        try:
            slide_iterator, num_batches, slide_name, act_slide_size = SlideHandler.get_slide_iterator(
                wsi_filepath, down_scale_rate=crop_down_scale, overlapp=512)
            wsi_seg_results = np.zeros(act_slide_size, dtype=np.float16)
            wsi_img_results = np.zeros(act_slide_size + [3], dtype=np.uint8)
            wsi_mask = np.zeros(
                act_slide_size, dtype=np.float16
            )  # counts how many patches covered each pixel
            candidates = []
            with sess.as_default():
                for step, (batch_imgs, locs) in enumerate(slide_iterator):
                    # locs[0]: (y, x)

                    sys.stdout.write('{}-{},'.format(step,
                                                     (batch_imgs.shape[0])))
                    sys.stdout.flush()
                    feed_dict = {img: batch_imgs, K.learning_phase(): False}
                    batch_pred = sess.run(unet_pred, feed_dict=feed_dict)
                    batch_logits = batch_pred[:, :, :, 1]
                    # put the results back into the slide-sized buffers
                    for seg, im, loc in zip(batch_logits, batch_imgs, locs):
                        y, x = loc[0], loc[1]
                        # there is overlapping
                        seg_h, seg_w = seg.shape
                        # prevent overflow; does not usually happen
                        if seg_h + y > wsi_seg_results.shape[0]:
                            y = wsi_seg_results.shape[0] - seg_h
                        if seg_w + x > wsi_seg_results.shape[1]:
                            x = wsi_seg_results.shape[1] - seg_w
                        wsi_mask[y:y + seg_h,
                                 x:x + seg_w] = wsi_mask[y:y + seg_h,
                                                         x:x + seg_w] + 1

                        # Overlapping regions keep the element-wise maximum.
                        wsi_seg_results[y:y + seg_h, x:x + seg_w] = np.maximum(
                            wsi_seg_results[y:y + seg_h, x:x + seg_w],
                            seg.astype(np.float16))

                        wsi_img_results[y:y + seg_h,
                                        x:x + seg_w] = im.astype(np.uint8)
                        candidates.append([(y, x), seg.copy(), im.copy()])

            # Saving segmentation and sampling results
            # NOTE(review): the extension is stripped a second time here, so
            # for names containing a dot this directory can differ from the
            # one tested by the skip check above - confirm.
            cur_dir = os.path.join(args.res_dir,
                                   os.path.splitext(wsi_img_name)[0])
            if not os.path.exists(cur_dir):
                os.makedirs(cur_dir)

            # Sample ROI randomly
            sample_patches = patch_sampling(
                candidates,
                tot_samples=args.num_samples,
                stride_ratio=0.01,
                sample_size=[sampled_img_size, sampled_img_size],
                threshold=args.threshold)

            for idx, patch in enumerate(sample_patches):
                file_name_surfix = '_' + str(idx).zfill(5) + '_' + str(patch[0][0]).zfill(6) + \
                                '_' + str(patch[0][1]).zfill(6)
                cur_patch_path = os.path.join(cur_dir,
                                              wsi_img_name + file_name_surfix)
                sample_img = (patch[1]).astype(np.uint8)
                misc.imsave(cur_patch_path + '.png',
                            sample_img)  # only target images are png format

            sample_locs = [a[0] for a in sample_patches]
            visualize_heatmap((wsi_seg_results * 255.0).astype(np.uint8),
                              shape=wsi_img_results.shape,
                              stride=sampled_img_size,
                              wsi_img=wsi_img_results,
                              save_path=os.path.join(cur_dir, wsi_img_name))

            misc.imsave(os.path.join(cur_dir, wsi_img_name + '_seg.jpg'),
                        misc.imresize(wsi_seg_results, 0.5).astype(np.uint8))
            misc.imsave(os.path.join(cur_dir, wsi_img_name + '_thumb.jpg'),
                        misc.imresize(wsi_img_results, 0.5).astype(np.uint8))

            wsi_img_point_results = visualize_sampling_points(wsi_img_results,
                                                              sample_locs,
                                                              path=None)
            misc.imsave(
                os.path.join(cur_dir, wsi_img_name + '_samplepoint.jpg'),
                misc.imresize(wsi_img_point_results, 0.5).astype(np.uint8))
            elapsed_time = datetime.now() - start_time
            print('=> Time {}'.format(elapsed_time))
        except Exception as e:
            # Best effort: one bad slide must not stop the batch, but the
            # traceback is kept so the failure can be diagnosed.
            print(e)
            traceback.print_exc()
示例#17
0
        raise SystemExit()
    else:
        out_files = output

    return out_files


def mask_to_image(mask):
    """Convert a mask array to an image object.

    The mask is multiplied by 255, cast to ``uint8`` and handed to
    ``Image.fromarray``.

    Args:
        mask: array-like supporting ``* 255`` and ``.astype``
            (presumably values in [0, 1] - confirm against callers).

    Returns:
        The object returned by ``Image.fromarray``.
    """
    scaled = mask * 255
    pixels = scaled.astype(np.uint8)
    return Image.fromarray(pixels)


if __name__ == "__main__":
    in_files = pre_config.pre_img
    out_files = get_output_filenames(pre_config.pre_img, pre_config.output)

    net = UNet(n_channels=pre_config.channels, n_classes=1)

    logging.info("Loading model {}".format(pre_config.model))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')
    net.to(device=device)
    net.load_state_dict(torch.load(pre_config.model, map_location=device))

    logging.info("Model loaded !")

    for i, fn in enumerate(in_files):
        logging.info("\nPredicting image {} ...".format(fn))

        img = Image.open(fn)
        if 'RIGHT' in fn:
示例#18
0

def save_checkpoint(model, save_path):
    """Save ``model.state_dict()`` to ``save_path``, creating parent folders.

    Args:
        model: object exposing ``state_dict()``.
        save_path: destination file path; may be a bare file name.
    """
    parent_dir = os.path.dirname(save_path)
    # A bare file name has an empty dirname, and os.makedirs('') raises, so
    # directories are only created when there is one to create.
    if parent_dir:
        # exist_ok avoids a race between an existence check and the creation.
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(model.state_dict(), save_path)


# Module-level setup: create the summary writer, the previous-stage model, the
# UNet and (optionally) the discriminators, and move each network to the GPU.

# The summary writer is created only when single_gpu_flag(opt) is true, so
# ``board`` is undefined otherwise.
if single_gpu_flag(opt):
    board = SummaryWriter(os.path.join('runs', opt.name))

# Model built by create_model(opt).
# NOTE(review): the name suggests an earlier-stage network - confirm.
prev_model = create_model(opt)
prev_model.cuda()

# UNet with 4 input channels and 3 output classes.
model = UNet(n_channels=4, n_classes=3)
# In distributed mode, BatchNorm layers are converted to SyncBatchNorm before
# weights are initialised and the model is moved to the GPU.
if opt.distributed:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# NOTE(review): this uses a bare ``weights_init`` while the discriminators
# below use ``utils.weights_init`` - confirm both names resolve as intended.
model.apply(weights_init('kaiming'))
model.cuda()

# The two discriminators are only built when GAN training is enabled.
if opt.use_gan:
    discriminator = Discriminator()
    discriminator.apply(utils.weights_init('gaussian'))
    discriminator.cuda()

    acc_discriminator = AccDiscriminator()
    acc_discriminator.apply(utils.weights_init('gaussian'))
    acc_discriminator.cuda()

if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
示例#19
0
                        help='filenames of input images')
    parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
                        help='filenames of ouput images')
    parser.add_argument('--cpu', '-c', action='store_true',
                        help="Do not use the cuda version of the net",
                        default=False)
    parser.add_argument('--viz', '-v', action='store_true',
                        help="Visualize the images as they are processed",
                        default=False)
    parser.add_argument('--no-save', '-n', action='store_false',
                        help="Do not save the output masks",
                        default=False)

    args = parser.parse_args()
    print("Using model file : {}".format(args.model))
    net = UNet(3, 3)
    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
    else:
        net.cpu()
        print("Using CPU version of the net, this may be very slow")


    #in_files = args.input
    in_files = open('data/test.txt').readlines()
    out_files = []
    save_path = 'experiment/up1/predict/scnn_xaiver_out'
    if not args.output:
        for f in in_files:
            pathsplit = os.path.splitext(f)
示例#20
0
def main(args):
  """Set up data, model, losses and optimizer from ``args`` and train or evaluate.

  The function, in order: creates a TensorBoard writer under
  ``args.experiment + "/logs"``; seeds torch, numpy and ``random`` with a
  fixed value; selects train/val transforms from the dataset flags; builds
  the model selected by the ``use_*`` / ``train_denoiser`` flags; builds the
  two criteria and an SGD optimizer; optionally restores weights from
  ``args.resume``; builds the validation loader; and then either evaluates
  (``args.evaluate``), trains a denoiser (``args.train_denoiser``) or runs
  an initial validation, optional warmup epochs and the main training loop.

  Args:
    args: options object; many attributes are read from it (see the body).
  """
  # parse args
  best_acc1 = 0.0
  # tensorboard writer
  writer = SummaryWriter(args.experiment + "/logs")

  if args.gpu >= 0:
    print("Use GPU: {}".format(args.gpu))
  else:
    print('Using CPU for computing!')

  # Seed every RNG used here with the same fixed value.
  fixed_random_seed = 2019
  torch.manual_seed(fixed_random_seed)
  np.random.seed(fixed_random_seed)
  random.seed(fixed_random_seed)


  # set up transforms for data augmentation
  # Per-channel mean/std come from args when given, otherwise the defaults below.
  mn = [float(x) for x in args.mean] if(args.mean) else [0.485, 0.456, 0.406] 
  st = [float(x) for x in args.std] if(args.std) else [0.229, 0.224, 0.225] 

  normalize = transforms.Normalize(mean=mn, std=st)
  train_transforms = get_train_transforms(normalize)
  val_transforms = get_val_transforms(normalize)
  # NOTE(review): the flags are tested in two separate if/elif chains, so a
  # match in the second chain (spad/cars/imagenet) overrides whatever the
  # first chain (denoiser/cub) selected - confirm this is intended.
  if(args.train_denoiser):
    normalize = transforms.Normalize(mean=mn, std=st)
    train_transforms = get_denoiser_train_transforms(normalize)
    val_transforms = get_denoiser_val_transforms(normalize)
  elif(args.cub_training):
    networks.CLASSES=200
    normalize = transforms.Normalize(mean=mn, std=st)
    train_transforms = get_cub_train_transforms(normalize)
    val_transforms = get_cub_val_transforms(normalize)
  if(args.spad_training):
    networks.CLASSES=122
    normalize = transforms.Normalize(mean=mn, std=st)
    train_transforms = get_spad_train_transforms(normalize)
    val_transforms = get_spad_val_transforms(normalize)
  elif(args.cars_training):
    # The cars branch reuses the CUB transforms.
    networks.CLASSES=196
    normalize = transforms.Normalize(mean=mn, std=st)
    train_transforms = get_cub_train_transforms(normalize)
    val_transforms = get_cub_val_transforms(normalize)
  elif(args.imagenet_training):
    networks.CLASSES=1000
    normalize = transforms.Normalize(mean=mn, std=st)
    train_transforms = get_imagenet_train_transforms(normalize)
    val_transforms = get_imagenet_val_transforms(normalize)
  if (not args.evaluate):
    print("Training time data augmentations:")
    print(train_transforms)


  # Build the model; the first truthy flag wins.  Several flags double as a
  # checkpoint path (or the string "random" for no pretrained weights).
  model_clean=None
  model_teacher=None
  if args.use_resnet18:
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, networks.CLASSES)
    if(args.use_resnet18!="random"):
        model.load_state_dict(torch.load(args.use_resnet18)['state_dict'])
  elif args.use_resnet34:
    model = torchvision.models.resnet34(pretrained=False)
    model.fc = nn.Linear(512, networks.CLASSES)
  elif args.use_resnet50:
    model = torchvision.models.resnet50(pretrained=False)
    model.fc = nn.Linear(2048, networks.CLASSES)
  elif args.use_inception_v3:
    model = torchvision.models.inception_v3(pretrained=False, aux_logits=False)
    model.fc = nn.Linear(2048, networks.CLASSES)
  elif args.use_photon_net:
    model = networks.ResNetContrast(BasicBlock, [2, 2, 2, 2], networks.CLASSES)
    if(args.use_photon_net!="random"):
        model.load_state_dict(torch.load(args.use_photon_net)['state_dict'])
#  elif args.use_contrastive_allfeats:
#    model = networks.ResNetContrast2(BasicBlock, [2, 2, 2, 2], networks.CLASSES)
#    if(args.use_contrastive_allfeats!="random"):
#        model.load_state_dict(torch.load(args.use_contrastive_allfeats)['state_dict'])
  elif args.train_denoiser:
    model = UNet(3,3)
  elif args.use_dirty_pixel:
    # A ResNet-18 classifier plus a separate UNet held in model_clean.
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, networks.CLASSES)
    model_clean = UNet(3,3)
    if(args.evaluate==False):
        model_clean.load_state_dict(torch.load(args.use_dirty_pixel)['state_dict'])
        # NOTE(review): moved to the GPU without checking args.gpu >= 0.
        model_clean = model_clean.cuda(args.gpu)
  elif args.use_student_teacher:
    model = networks.ResNetPerceptual(BasicBlock, [2, 2, 2, 2], networks.CLASSES)
    model_teacher = networks.ResNetPerceptual(BasicBlock, [2, 2, 2, 2], networks.CLASSES, teacher_model=True)
    model_teacher.load_state_dict(torch.load(args.use_student_teacher)['state_dict'])
    model_teacher = model_teacher.cuda(args.gpu)
    # The teacher is frozen: eval mode and no gradients.
    model_teacher.eval()
    for param in model_teacher.parameters():
      param.requires_grad = False
  else:
    print("select correct model")
    # NOTE(review): exits with status 0 although no model was selected.
    exit(0)

  # criterion1 is always cross entropy; criterion2 is MSE for the
  # student-teacher / denoiser / dirty-pixel modes, contrastive otherwise.
  criterion1 = nn.CrossEntropyLoss()
  if(args.use_student_teacher or args.train_denoiser or args.use_dirty_pixel):
    criterion2 = nn.MSELoss()
  else:
    ps = AllPositivePairSelector(balance=False)
    criterion2 = OnlineContrastiveLoss(1., ps)
  # put everything on the gpu
  if args.gpu >= 0:
    model = model.cuda(args.gpu)
    criterion1 = criterion1.cuda(args.gpu)
    #criterion3 = criterion3.cuda(args.gpu)
    criterion2 = criterion2.cuda(args.gpu)
  criterion = [criterion1, criterion2]
  #criterion = [criterion1]
  # setup the optimizer
  opt_params = model.parameters()
  if(args.use_dirty_pixel):
    # Optimise the classifier and the UNet jointly.
    opt_params = list(model.parameters()) + list(model_clean.parameters())
  optimizer = torch.optim.SGD(opt_params, args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  # resume from a checkpoint?
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      if(args.gpu>=0):
        checkpoint = torch.load(args.resume)
      else:
        # On CPU, keep every tensor on the storage it is deserialised to.
        checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
      #best_acc1 = checkpoint['best_acc1']


      #new_state_dict = OrderedDict()
      #model_dict = model.state_dict()
      #for k, v in checkpoint['state_dict'].items():
      #  name = k[7:] # remove `module.`
      #  if(name.startswith('fc')):
      #      continue
      #  new_state_dict[name] = v
      #model_dict.update(new_state_dict)
      #model.load_state_dict(model_dict)
      model.load_state_dict(checkpoint['state_dict'])
      if args.gpu < 0:
        model = model.cpu()
      else:
        model = model.cuda(args.gpu)
      if(args.use_dirty_pixel):
        model_clean.load_state_dict(checkpoint['model_clean_state_dict'])
        model_clean = model_clean.cuda(args.gpu)
#      # only load the optimizer if necessary
#      if (not args.evaluate):
#        args.start_epoch = checkpoint['epoch']
#        optimizer.load_state_dict(checkpoint['optimizer'])
      # best_acc1 is still its initial 0.0 here; only the weights are restored.
      print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
          .format(args.resume, checkpoint['epoch'], best_acc1))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))


  # setup dataset and dataloader
  val_dataset = IMMetricLoader(args.data_folder,
                  split='val', transforms=val_transforms, image_id=False, pil_loader=args.pil_loader, clean_image=args.train_denoiser, label_file=args.val_label_file)

  print('Validation Set Size: ', len(val_dataset))

  # The denoiser is validated one image at a time; classifiers in batches of 50.
  val_batch_size = 1 if args.train_denoiser else 50
  val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=val_batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)
  val_dataset.reset_seed()
  # evaluation
  if args.evaluate:
    print("Testing the model ...")
    cudnn.deterministic = True
    validate_model(val_loader, model, -1, args, writer, model_clean)
    return
  # Clean reference images are loaded only for the modes that use criterion2 as MSE.
  load_clean_image = args.use_student_teacher or args.train_denoiser or args.use_dirty_pixel
  train_dataset = IMMetricLoader(args.data_folder,
                  split='train', transforms=train_transforms, label_file=args.label_file, pil_loader=args.pil_loader, clean_image=load_clean_image)
  print('Training Set Size: ', len(train_dataset))
  if(args.num_instances):
    # Identity-based sampling replaces shuffling when num_instances is set.
    train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size,
    num_workers=args.workers, pin_memory=True, sampler=RandomIdentitySampler(train_dataset, num_instances=args.num_instances), drop_last=True)
  else:
    train_loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=args.batch_size, shuffle=True,
      num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)

  # enable cudnn benchmark
  cudnn.enabled = True
  cudnn.benchmark = True


  # Denoiser training has its own loop and returns when finished.
  if(args.train_denoiser):
    
    print("Training denoiser ...")
    for epoch in range(args.start_epoch, args.epochs):
      train_dataset.reset_seed()
      train_denoiser(train_loader, val_loader, model, criterion, optimizer, epoch, args, writer)
    return

  # Run one validation pass before any training.
  model.eval()
  # NOTE(review): these meters are never updated in this function before
  # their ``.avg`` is logged below, and val_acc1 is not used afterwards.
  top1 = AverageMeter()
  top5 = AverageMeter()
  val_acc1 = validate_model(val_loader, model, 0, args, writer, model_clean)
  writer.add_scalars('data/top1_accuracy',
     {"train" : top1.avg}, 0)
  writer.add_scalars('data/top5_accuracy',
     {"train" : top5.avg}, 0)
  model.train()

  # warmup the training
  if (args.start_epoch == 0) and (args.warmup_epochs > 0):
    print("Warmup the training ...")
    for epoch in range(0, args.warmup_epochs):
      acc1 = train_model(train_loader, val_loader, model, criterion, optimizer, epoch, "warmup", best_acc1, args, writer, model_clean, model_teacher)

  # start the training
  print("Training the model ...")
  for epoch in range(args.start_epoch, args.epochs):
    train_dataset.reset_seed()
    # train for one epoch
    acc1 = train_model(train_loader, val_loader, model, criterion, optimizer, epoch, "train", best_acc1, args, writer, model_clean, model_teacher)


    # track the best accuracy seen so far (no checkpoint is written here)
    best_acc1 = max(acc1, best_acc1)
示例#21
0
    # We do not expect these to be non-zero for an accurate mask,
    # so this should not harm the score.
    pixels[0] = 0
    pixels[-1] = 0
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    runs[1::2] = runs[1::2] - runs[:-1:2]
    return runs


def submit(net):
    """Used for Kaggle submission: predicts and encodes all test images.

    Every file in ``data/test/`` is opened, passed through ``predict_img``
    and run-length encoded; one ``<file name>,<rle>`` row per image is
    appended to ``SUBMISSION.csv``.

    Args:
        net: network forwarded unchanged to ``predict_img``.
    """
    test_dir = 'data/test/'
    out_path = 'SUBMISSION.csv'
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # List the directory once so the reported total always matches the files
    # that are actually iterated.
    file_names = os.listdir(test_dir)
    total = len(file_names)
    # The file is opened in append mode, so the header is only written when
    # the file is new or empty; otherwise a re-run would insert a second
    # header row in the middle of the CSV.
    needs_header = (not os.path.exists(out_path)
                    or os.path.getsize(out_path) == 0)
    with open(out_path, 'a') as f:
        if needs_header:
            f.write('img,rle_mask\n')
        for index, file_name in enumerate(file_names):
            print('{}/{}'.format(index, total))

            img = Image.open(os.path.join(test_dir, file_name))

            mask = predict_img(net, img, device)
            enc = rle_encode(mask)
            f.write('{},{}\n'.format(file_name, ' '.join(map(str, enc))))


if __name__ == '__main__':
    # Choose the device the same way submit() does, so the script also runs
    # on machines without CUDA instead of failing on an unconditional .cuda().
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(3, 1).to(device)
    # map_location lets weights saved on a GPU be loaded on a CPU-only host.
    net.load_state_dict(torch.load('MODEL.pth', map_location=device))
    submit(net)
示例#22
0
class Noise2Noise(object):
    """Implementation of Noise2Noise from Lehtinen et al. (2018)."""

    def __init__(self, params, trainable):
        """Initializes model."""

        self.p = params
        self.trainable = trainable
        self._compile()


    def _compile(self):
        """Builds the network and, when trainable, optimizer, scheduler and loss.

        Sets ``self.is_mc``, ``self.model`` and ``self.use_cuda``; in training
        mode also ``self.optim``, ``self.scheduler`` and ``self.loss``.
        """

        print('Noise2Noise: Learning Image Restoration without Clean Data (Lethinen et al., 2018)')

        # Monte Carlo input uses 3x3=9 channels (3 HDR buffers), otherwise 3
        self.is_mc = self.p.noise_type == 'mc'
        self.model = UNet(in_channels=9 if self.is_mc else 3)

        if self.trainable:
            # self.p.adam holds (beta1, beta2, eps)
            adam_betas = self.p.adam[:2]
            adam_eps = self.p.adam[2]
            self.optim = Adam(self.model.parameters(),
                              lr=self.p.learning_rate,
                              betas=adam_betas,
                              eps=adam_eps)

            # Halve the learning rate when the monitored value plateaus
            self.scheduler = lr_scheduler.ReduceLROnPlateau(
                self.optim, patience=self.p.nb_epochs/4, factor=0.5, verbose=True)

            # Pick the loss by name; anything but 'hdr' / 'l2' falls back to L1
            loss_name = self.p.loss
            if loss_name == 'l2':
                self.loss = nn.MSELoss()
            elif loss_name == 'hdr':
                assert self.is_mc, 'Using HDR loss on non Monte Carlo images'
                self.loss = HDRLoss()
            else:
                self.loss = nn.L1Loss()

        # CUDA is used only when it is both available and requested
        self.use_cuda = torch.cuda.is_available() and self.p.cuda
        if not self.use_cuda:
            return
        self.model = self.model.cuda()
        if self.trainable:
            self.loss = self.loss.cuda()


    def _print_params(self):
        """Formats parameters to print when training."""

        print('Training parameters: ')
        self.p.cuda = self.use_cuda
        param_dict = vars(self.p)
        pretty = lambda x: x.replace('_', ' ').capitalize()
        print('\n'.join('  {} = {}'.format(pretty(k), str(v)) for k, v in param_dict.items()))
        print()


    def save_model(self, epoch, stats, first=False):
        """Saves model to files; can be overwritten at every epoch to save disk space."""

        # Create directory for model checkpoints, if nonexistent
        if first:
            if self.p.clean_targets:
                ckpt_dir_name = f'{datetime.now():{self.p.noise_type}-clean-%H%M}'
            else:
                ckpt_dir_name = f'{datetime.now():{self.p.noise_type}-%H%M}'
            if self.p.ckpt_overwrite:
                if self.p.clean_targets:
                    ckpt_dir_name = f'{self.p.noise_type}-clean'
                else:
                    ckpt_dir_name = self.p.noise_type

            self.ckpt_dir = os.path.join(self.p.ckpt_save_path, ckpt_dir_name)
            if not os.path.isdir(self.p.ckpt_save_path):
                os.mkdir(self.p.ckpt_save_path)
            if not os.path.isdir(self.ckpt_dir):
                os.mkdir(self.ckpt_dir)

        # Save checkpoint dictionary
        if self.p.ckpt_overwrite:
            fname_unet = '{}/n2n-{}.pt'.format(self.ckpt_dir, self.p.noise_type)
        else:
            valid_loss = stats['valid_loss'][epoch]
            fname_unet = '{}/n2n-epoch{}-{:>1.5f}.pt'.format(self.ckpt_dir, epoch + 1, valid_loss)
        print('Saving checkpoint to: {}\n'.format(fname_unet))
        torch.save(self.model.state_dict(), fname_unet)

        # Save stats to JSON
        fname_dict = '{}/n2n-stats.json'.format(self.ckpt_dir)
        with open(fname_dict, 'w') as fp:
            json.dump(stats, fp, indent=2)


    def load_model(self, ckpt_fname):
        """Loads model from checkpoint file."""

        print('Loading checkpoint from: {}'.format(ckpt_fname))
        if self.use_cuda:
            self.model.load_state_dict(torch.load(ckpt_fname))
        else:
            self.model.load_state_dict(torch.load(ckpt_fname, map_location='cpu'))


    def _on_epoch_end(self, stats, train_loss, epoch, epoch_start, valid_loader):
        """Validates, steps the scheduler, records stats and saves after an epoch.

        Args:
            stats: Dictionary whose 'train_loss', 'valid_loss' and 'valid_psnr'
                lists each receive one new entry.
            train_loss: Training loss to record for this epoch.
            epoch: Zero-based epoch index; epoch 0 triggers creation of the
                checkpoint directory in ``save_model``.
            epoch_start: Start time passed to ``time_elapsed_since``.
            valid_loader: Loader handed to ``self.eval``.
        """

        # Measure the epoch time before validation starts, then validate
        print('\rTesting model on validation set... ', end='')
        epoch_time = time_elapsed_since(epoch_start)[0]
        valid_loss, valid_time, valid_psnr = self.eval(valid_loader)
        show_on_epoch_end(epoch_time, valid_time, valid_loss, valid_psnr)

        # The scheduler lowers the learning rate when validation loss plateaus
        self.scheduler.step(valid_loss)

        # Record this epoch's numbers and write the checkpoint
        recorded = (('train_loss', train_loss),
                    ('valid_loss', valid_loss),
                    ('valid_psnr', valid_psnr))
        for key, value in recorded:
            stats[key].append(value)
        is_first_epoch = epoch == 0
        self.save_model(epoch, stats, is_first_epoch)

        # Plotting is optional
        if not self.p.plot_stats:
            return
        loss_str = f'{self.p.loss.upper()} loss'
        plot_per_epoch(self.ckpt_dir, 'Valid loss', stats['valid_loss'], loss_str)
        plot_per_epoch(self.ckpt_dir, 'Valid PSNR', stats['valid_psnr'], 'PSNR (dB)')


    def test(self, test_loader, show):
        """Evaluates denoiser on test set."""

        self.model.train(False)

        source_imgs = []
        denoised_imgs = []
        clean_imgs = []

        # Create directory for denoised images
        denoised_dir = os.path.dirname(self.p.data)
        save_path = os.path.join(denoised_dir, 'denoised')
        if not os.path.isdir(save_path):
            os.mkdir(save_path)

        for batch_idx, (source, target) in enumerate(test_loader):
            # Only do first <show> images
            if show == 0 or batch_idx >= show:
                break

            source_imgs.append(source)
            clean_imgs.append(target)

            if self.use_cuda:
                source = source.cuda()

            # Denoise
            denoised_img = self.model(source).detach()
            denoised_imgs.append(denoised_img)

        # Squeeze tensors
        source_imgs = [t.squeeze(0) for t in source_imgs]
        denoised_imgs = [t.squeeze(0) for t in denoised_imgs]
        clean_imgs = [t.squeeze(0) for t in clean_imgs]

        # Create montage and save images
        print('Saving images and montages to: {}'.format(save_path))
        for i in range(len(source_imgs)):
            img_name = test_loader.dataset.imgs[i]
            create_montage(img_name, self.p.noise_type, save_path, source_imgs[i], denoised_imgs[i], clean_imgs[i], show)

    # Whole-image output
    def test2(self):
        """Evaluates denoiser on test set."""

        self.model.train(False)

        # Create directory for denoised images
        denoised_dir = os.path.dirname(self.p.data)
        save_path = os.path.join(denoised_dir, 'denoised')
        if not os.path.isdir(save_path):
            os.mkdir(save_path)

        namelist = os.listdir(denoised_dir)
        filelist=[os.path.join(denoised_dir,name) for name in namelist if (name != "denoised")]

        # Load PIL image
        for img_path in filelist:
            img =  Image.open(img_path).convert('RGB')
            w, h = img.size
            if w % 32 != 0:
                w = (w//32)*32
            if h % 32 != 0:
                h = (h//32)*32
            img = tvF.crop(img, 0, 0, h, w)
            source = tvF.to_tensor(img)
            source = source.unsqueeze(0)
            print(source.size())
            if self.use_cuda:
                source = source.cuda()
            # Denoise
            denoised = self.model(source).detach()
            denoised = denoised.cpu()
            denoised = denoised.squeeze(0)
            print("--------->",denoised.size())
            denoised = tvF.to_pil_image(denoised)
            fname = os.path.basename(img_path)
            denoised.save(os.path.join(save_path, f'{fname}-denoised.png'))

    def eval(self, valid_loader):
        """Evaluates denoiser on validation set.

        Args:
            valid_loader: Iterable of ``(source, target)`` batches.

        Returns:
            Tuple ``(valid_loss, valid_time, psnr_avg)``: the average loss over
            batches, the first element returned by ``time_elapsed_since`` for
            the whole pass, and the average per-image PSNR.

        Note:
            The model is left in evaluation mode when this method returns.
        """

        self.model.train(False)

        valid_start = datetime.now()
        loss_meter = AvgMeter()
        psnr_meter = AvgMeter()

        # Validation needs no gradients; skipping them saves memory
        with torch.no_grad():
            for source, target in valid_loader:
                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise
                source_denoised = self.model(source)

                # Update loss
                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Compute PSNR
                if self.is_mc:
                    source_denoised = reinhard_tonemap(source_denoised)

                # Move each batch to the CPU once, then walk the images that
                # are actually present. Indexing up to the configured batch
                # size raised IndexError on a shorter final batch.
                source_denoised = source_denoised.cpu()
                target = target.cpu()
                for denoised_img, target_img in zip(source_denoised, target):
                    psnr_meter.update(psnr(denoised_img, target_img).item())

        valid_loss = loss_meter.avg
        valid_time = time_elapsed_since(valid_start)[0]
        psnr_avg = psnr_meter.avg

        return valid_loss, valid_time, psnr_avg


    def train(self, train_loader, valid_loader):
        """Trains denoiser on training set.

        Runs ``self.p.nb_epochs`` epochs of minibatch updates and, after each
        epoch, validates, steps the scheduler and saves through
        ``_on_epoch_end``.

        Args:
            train_loader: Sized iterable of ``(source, target)`` batches.
            valid_loader: Loader forwarded to ``_on_epoch_end``.

        Raises:
            AssertionError: If ``self.p.report_interval`` does not divide the
                number of training batches.
        """

        self.model.train(True)

        self._print_params()
        num_batches = len(train_loader)
        # A report interval of 0 means "report once per epoch"
        if self.p.report_interval == 0:
            self.p.report_interval = num_batches
        print("--------->num_batches:",num_batches)    
        print("--------->report_interval:",self.p.report_interval)  
        assert num_batches % self.p.report_interval == 0, 'Report interval must divide total number of batches'

        # Dictionaries of tracked stats
        stats = {'noise_type': self.p.noise_type,
                 'noise_param': self.p.noise_param,
                 'train_loss': [],
                 'valid_loss': [],
                 'valid_psnr': []}

        # Main training loop
        train_start = datetime.now()
        for epoch in range(self.p.nb_epochs):
            print('EPOCH {:d} / {:d}'.format(epoch + 1, self.p.nb_epochs))

            # The validation pass at the end of the previous epoch switches
            # the model to evaluation mode, so training mode has to be
            # restored here; otherwise only the first epoch trains in it.
            self.model.train(True)

            # Some stats trackers
            epoch_start = datetime.now()
            train_loss_meter = AvgMeter()
            loss_meter = AvgMeter()
            time_meter = AvgMeter()

            # Minibatch SGD
            for batch_idx, (source, target) in enumerate(train_loader):
                batch_start = datetime.now()
                progress_bar(batch_idx, num_batches, self.p.report_interval, loss_meter.val)

                if self.use_cuda:
                    source = source.cuda()
                    target = target.cuda()

                # Denoise image
                source_denoised = self.model(source)

                loss = self.loss(source_denoised, target)
                loss_meter.update(loss.item())

                # Zero gradients, perform a backward pass, and update the weights
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                # Report/update statistics
                time_meter.update(time_elapsed_since(batch_start)[1])
                if (batch_idx + 1) % self.p.report_interval == 0 and batch_idx:
                    show_on_report(batch_idx, num_batches, loss_meter.avg, time_meter.avg)
                    train_loss_meter.update(loss_meter.avg)
                    loss_meter.reset()
                    time_meter.reset()

            # Epoch end, save and reset tracker
            self._on_epoch_end(stats, train_loss_meter.avg, epoch, epoch_start, valid_loader)
            train_loss_meter.reset()

        train_elapsed = time_elapsed_since(train_start)[0]
        print('Training done! Total elapsed time: {}\n'.format(train_elapsed))
示例#23
0
# Dataset location and loader settings for the parking images
dataset_dir = 'parking/'
batch_size = 1

# Convert to tensor, then normalise each of the three channels with mean 0.5 / std 0.5
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = parking_data(root=dataset_dir, l=0.05, is_train=False, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

# using GPU when available, CPU otherwise
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

net = UNet(n_channels=1, n_classes=1)
net = net.to(device)

# map_location keeps the CPU fallback above usable: without it a checkpoint
# saved from a GPU cannot be loaded on a machine without CUDA.
net.load_state_dict(torch.load('./unet/ckpt/parking_unet_model_epoch40.pth', map_location=device))

net.eval()
# One forward pass on an all-zero 1x1x320x320 input; the result is not used below
dummy = net(torch.zeros(1,1,320,320).to(device))

############## Character Recognition ##############
# Initialize the network
global_step = tf.Variable(0, trainable=False)
logits, input_plate, seq_len = get_train_model(num_channels, label_len, BATCH_SIZE, img_size)
# Swap the first two axes of the logits before CTC decoding
logits = tf.transpose(logits, (1, 0, 2))
decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)
session = tf.Session()
init = tf.global_variables_initializer()
    val_loader = DataLoader(val_set,
                            batch_size=args.batchsize,
                            shuffle=False,
                            num_workers=args.num_workers)

    test_set = data_loader(img_tests,
                           transform=data_transforms['val'],
                           APS=args.APS,
                           isTrain=False)
    print('Number of test patches extracted: ', len(test_set))
    test_loader = DataLoader(test_set,
                             batch_size=args.batchsize,
                             shuffle=False,
                             num_workers=args.num_workers)

    net = UNet(n_channels=3, n_classes=args.n_classes, bilinear=False)

    if args.gpu:
        net.cuda()
        net = torch.nn.DataParallel(net, device_ids=[0, 1])
        cudnn.benchmark = True  # faster convolutions, but more memory
    try:
        train_net(net=net,
                  train_loader=train_loader,
                  val_loader=val_loader,
                  test_loader=test_loader,
                  args=args)

    except KeyboardInterrupt:
        torch.save(net.state_dict(),
                   'INTERRUPTED_res{}.pth'.format(args.resolution))
示例#25
0
        # true_mask = torch.from_numpy(true_mask)

        mask_pred = net(img.cuda())[0]
        mask_pred = mask_pred.data.cpu().numpy()
        mask_pred = (mask_pred > 0.5).astype(int)
        #
        # print('mask_pred.shape.zhaojin', mask_pred.shape)
        # print('true_mask.shape.zhaojin', true_mask.shape)

        tot += dice_cofe(mask_pred, true_mask)
    return tot / (i + 1)


if __name__ == "__main__":

    # Locations of the images, masks and checkpoints
    dir_img = '/home/zhaojin/data/TacomaBridge/segdata/train/img'
    dir_mask = '/home/zhaojin/data/TacomaBridge/segdata/train/mask'
    dir_checkpoint = '/home/zhaojin/data/TacomaBridge/segdata/train/checkpoint/logloss_softmax/'

    # Network on the GPU
    net = UNet(n_channels=1, n_classes=4).cuda()

    # Collect the ids, then split them with 0.1 passed as the validation share
    ids = split_ids(get_ids(dir_img))
    iddataset = split_train_val(ids, 0.1)

    # Score the validation split and report the result
    val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, 0.5)
    val_dice = eval_net(net, val)
    print('Validation Dice Coeff: {}'.format(val_dice))
示例#26
0
def save_ckp(state, f_path="/media/disk2/sombit/kitti_seg/checkpoint.pt"):
    """Serialise a checkpoint object with ``torch.save``.

    Args:
        state: Object to save (``load_ckp`` expects a dict with
            'state_dict', 'optimizer' and 'epoch' keys).
        f_path: Destination file. Defaults to the previously hard-coded
            location, so existing one-argument calls behave as before.
    """
    torch.save(state, f_path)


def load_ckp(checkpoint_fpath, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        checkpoint_fpath: File readable by ``torch.load`` holding a dict with
            'state_dict', 'optimizer' and 'epoch' entries.
        model: Module that receives ``checkpoint['state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer']``.

    Returns:
        Tuple ``(model, optimizer, epoch)`` with the same two objects after
        loading, plus the stored epoch value.
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint['state_dict'])
    # Restore into the optimizer that was passed in. The previous code used a
    # module-level ``optim`` here and ignored the ``optimizer`` argument.
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']


# Module-level training setup: device, model, optimizer and data loader.

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Build the UNet and move it to the selected device
model = UNet(padding=True, up_mode='upsample').to(device)
print("Load Model")
# Adam over all model parameters with the library's default hyper-parameters
optim = torch.optim.Adam(model.parameters())
# Earlier loader setup, kept commented out:
# data_loader = get_loader('kitti','seg')
# data_path = "/home/sombit/kitti"
# Dataset object; NOTE(review): data_loader is defined outside this snippet --
# the meaning of is_transform / img_norm should be confirmed there.
t_loader = data_loader(
    is_transform=True,
    img_norm=False,
)

# Shuffled loader yielding one sample per batch, using two worker processes
trainloader = data.DataLoader(t_loader,
                              batch_size=1,
                              num_workers=2,
                              shuffle=True)

# Number of training epochs
epochs = 100
示例#27
0
        print("Error : Input files and output files are not of the same length")
        raise SystemExit()
    else:
        out_files = args.output

    return out_files

def mask_to_image(mask):
    """Turn a mask array into a PIL image.

    The mask is multiplied by 255 and cast to ``uint8`` before conversion.
    """
    scaled = mask * 255
    pixels = scaled.astype(np.uint8)
    return Image.fromarray(pixels)

if __name__ == "__main__":
    args = get_args()
    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=1)

    print("Loading model {}".format(args.model))

    if not args.cpu:
        print("Using CUDA version of the net, prepare your GPU !")
        net.cuda()
        net.load_state_dict(torch.load(args.model))
    else:
        net.cpu()
        net.load_state_dict(torch.load(args.model, map_location='cpu'))
        print("Using CPU version of the net, this may be very slow")

    print("Model loaded !")

    for i, fn in enumerate(in_files):
示例#28
0
    args = edict()
    args.epochs = cfg.ARGS.EPOCHS
    args.batchsize = cfg.ARGS.BATCH_SIZE
    args.lr = cfg.ARGS.LR
    args.load_pth = cfg.ARGS.LOAD_PTH

    args.expname = opt.config.replace('.yml', '')
    args.scale = cfg.EXP.SCALE
    args.colormap = cfg.EXP.COLORMAP

    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = UNet(n_channels=3, n_classes=2)
    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')

    if args.load_pth:
        net.load_state_dict(torch.load(args.load_pth, map_location=device))
        logging.info(f'Model loaded from {args.load_pth}')

    net.to(device=device)

    try:
        train_net(net=net,
                  epochs=args.epochs,
示例#29
0
文件: train.py 项目: lighTQ/torchUNet
                      default=False, help='load file model')
    parser.add_option('-s', '--scale', dest='scale', type='float',
                      default=0.5, help='downscaling factor of the images')
    parser.add_option('-t', '--height', dest='height', type='int',
                      default=1024, help='rescale images to height')


    (options, args) = parser.parse_args()
    return options

if __name__ == '__main__':
    
    print("Let's record it \n")
    args = get_args()

    net = UNet(n_channels=3, n_classes=1)
    net = torch.nn.DataParallel(net)

    if args.load:
        net.load_state_dict(torch.load(args.load))
        print('Model loaded from {}'.format(args.load))

    if args.gpu:
        net.cuda()
        cudnn.benchmark = True # faster convolutions, but more memory

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
示例#30
0
def _load_checkpoint_if_present(model, checkpoint_path):
    """Load ``checkpoint_path`` into ``model`` if it is non-empty and exists.

    Returns True when a checkpoint was loaded, False otherwise.
    """
    if checkpoint_path == '' or not os.path.exists(checkpoint_path):
        return False
    load_checkpoint(model, checkpoint_path)
    return True


def _wrap_distributed(module, local_rank):
    """Wrap ``module`` in DistributedDataParallel bound to ``local_rank``."""
    return torch.nn.parallel.DistributedDataParallel(
        module,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True)


def _final_checkpoint_path(opt, filename):
    """Return ``<checkpoint_dir>/<name>/<filename>`` for the current run."""
    return os.path.join(opt.checkpoint_dir, opt.name, filename)


def _new_gan_module(factory):
    """Instantiate ``factory``, apply 'gaussian' weight init and move it to the GPU."""
    module = factory()
    module.apply(utils.weights_init('gaussian'))
    module.cuda()
    return module


def _build_residual_models(opt, generator_checkpoint):
    """Create the networks shared by the 'residual' and 'residual_old' stages.

    Returns ``(gmm_model, generator_model, embedder_model, model)``: the first
    three are loaded from fixed checkpoints (the generator from
    ``generator_checkpoint``); ``model`` is a fresh UNet with 'kaiming' init.
    """
    gmm_model = GMM(opt)
    load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
    gmm_model.cuda()

    generator_model = UnetGenerator(25,
                                    4,
                                    6,
                                    ngf=64,
                                    norm_layer=nn.InstanceNorm2d)
    load_checkpoint(generator_model, generator_checkpoint)
    generator_model.cuda()

    embedder_model = Embedder()
    load_checkpoint(embedder_model,
                    "checkpoints/identity_train_64_dim/step_020000.pth")
    # Only the ``embedder_b`` sub-module is used from here on
    embedder_model = embedder_model.embedder_b.cuda()

    model = UNet(n_channels=4, n_classes=3)
    if opt.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.apply(utils.weights_init('kaiming'))
    model.cuda()

    return gmm_model, generator_model, embedder_model, model


def main():
    """Train the stage selected by ``opt.stage`` and save its final checkpoint.

    Supported stages: 'GMM', 'TOM', 'TOM+WARP', 'identity', 'residual' and
    'residual_old'. Distributed mode is enabled when the ``WORLD_SIZE``
    environment variable is greater than 1.

    Raises:
        NotImplementedError: If ``opt.stage`` is not one of the stages above.
    """
    opt = get_opt()
    print(opt)
    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))

    n_gpu = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.distributed = n_gpu > 1
    local_rank = opt.local_rank

    if opt.distributed:
        torch.cuda.set_device(opt.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        synchronize()

    # create dataset
    train_dataset = CPDataset(opt)

    # create dataloader
    train_loader = CPDataLoader(opt, train_dataset)

    # visualization; exist_ok avoids a crash when several ranks create the
    # directory at the same time
    os.makedirs(opt.tensorboard_dir, exist_ok=True)

    board = None
    if single_gpu_flag(opt):
        board = SummaryWriter(
            log_dir=os.path.join(opt.tensorboard_dir, opt.name))

    # create model & train & save the final checkpoint
    if opt.stage == 'GMM':
        model = GMM(opt)
        _load_checkpoint_if_present(model, opt.checkpoint)
        train_gmm(opt, train_loader, model, board)
        save_checkpoint(model, _final_checkpoint_path(opt, 'gmm_final.pth'))

    elif opt.stage == 'TOM':
        gmm_model = GMM(opt)
        load_checkpoint(gmm_model, "checkpoints/gmm_train_new/step_020000.pth")
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        _load_checkpoint_if_present(model, opt.checkpoint)

        # model_module always refers to the unwrapped network (used for saving)
        model_module = model
        if opt.distributed:
            model = _wrap_distributed(model, local_rank)
            model_module = model.module

        train_tom(opt, train_loader, model, model_module, gmm_model, board)
        if single_gpu_flag(opt):
            save_checkpoint(model_module,
                            _final_checkpoint_path(opt, 'tom_final.pth'))

    elif opt.stage == 'TOM+WARP':
        # Unlike 'TOM', the GMM starts without a checkpoint and is trained jointly
        gmm_model = GMM(opt)
        gmm_model.cuda()

        model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
        model.cuda()
        _load_checkpoint_if_present(model, opt.checkpoint)

        model_module = model
        gmm_model_module = gmm_model
        if opt.distributed:
            model = _wrap_distributed(model, local_rank)
            model_module = model.module
            gmm_model = _wrap_distributed(gmm_model, local_rank)
            gmm_model_module = gmm_model.module

        train_tom_gmm(opt, train_loader, model, model_module, gmm_model,
                      gmm_model_module, board)
        if single_gpu_flag(opt):
            save_checkpoint(model_module,
                            _final_checkpoint_path(opt, 'tom_final.pth'))

    elif opt.stage == "identity":
        model = Embedder()
        _load_checkpoint_if_present(model, opt.checkpoint)
        train_identity_embedding(opt, train_loader, model, board)
        # NOTE(review): the embedder is saved under the GMM file name
        # ('gmm_final.pth'); kept as-is, confirm whether this is intended.
        save_checkpoint(model, _final_checkpoint_path(opt, 'gmm_final.pth'))

    elif opt.stage == 'residual':
        gmm_model, generator_model, embedder_model, model = \
            _build_residual_models(
                opt, "checkpoints/tom_train_new/step_038000.pth")

        if opt.use_gan:
            discriminator = _new_gan_module(Discriminator)
            acc_discriminator = _new_gan_module(AccDiscriminator)

        # The discriminator checkpoint sits next to the model checkpoint,
        # with "step_" replaced by "step_disc_" in the path
        if _load_checkpoint_if_present(model, opt.checkpoint) and opt.use_gan:
            load_checkpoint(discriminator,
                            opt.checkpoint.replace("step_", "step_disc_"))

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
            acc_discriminator_module = acc_discriminator

        if opt.distributed:
            model = _wrap_distributed(model, local_rank)
            model_module = model.module
            if opt.use_gan:
                discriminator = _wrap_distributed(discriminator, local_rank)
                discriminator_module = discriminator.module

                acc_discriminator = _wrap_distributed(acc_discriminator,
                                                      local_rank)
                acc_discriminator_module = acc_discriminator.module

        if opt.use_gan:
            train_residual(opt,
                           train_loader,
                           model,
                           model_module,
                           gmm_model,
                           generator_model,
                           embedder_model,
                           board,
                           discriminator=discriminator,
                           discriminator_module=discriminator_module,
                           acc_discriminator=acc_discriminator,
                           acc_discriminator_module=acc_discriminator_module)

            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    _final_checkpoint_path(opt, 'tom_final.pth'))
        else:
            train_residual(opt, train_loader, model, model_module, gmm_model,
                           generator_model, embedder_model, board)
            if single_gpu_flag(opt):
                save_checkpoint(model_module,
                                _final_checkpoint_path(opt, 'tom_final.pth'))

    elif opt.stage == "residual_old":
        gmm_model, generator_model, embedder_model, model = \
            _build_residual_models(
                opt, "checkpoints/tom_train_new_2/step_070000.pth")

        if opt.use_gan:
            discriminator = _new_gan_module(Discriminator)

        _load_checkpoint_if_present(model, opt.checkpoint)

        model_module = model
        if opt.use_gan:
            discriminator_module = discriminator
        if opt.distributed:
            model = _wrap_distributed(model, local_rank)
            model_module = model.module
            if opt.use_gan:
                discriminator = _wrap_distributed(discriminator, local_rank)
                discriminator_module = discriminator.module

        if opt.use_gan:
            train_residual_old(opt,
                               train_loader,
                               model,
                               model_module,
                               gmm_model,
                               generator_model,
                               embedder_model,
                               board,
                               discriminator=discriminator,
                               discriminator_module=discriminator_module)
            if single_gpu_flag(opt):
                save_checkpoint(
                    {
                        "generator": model_module,
                        "discriminator": discriminator_module
                    },
                    _final_checkpoint_path(opt, 'tom_final.pth'))
        else:
            train_residual_old(opt, train_loader, model, model_module,
                               gmm_model, generator_model, embedder_model,
                               board)
            if single_gpu_flag(opt):
                save_checkpoint(model_module,
                                _final_checkpoint_path(opt, 'tom_final.pth'))
    else:
        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)

    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
# Refuse to continue when CUDA was requested but is not available
cuda = opt.cuda
if cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")

# Seed the CPU generator (and the CUDA one when used) for repeatable runs
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

print('===> Loading datasets')
train_set = get_training_set(opt.size + opt.remsize, target_mode=opt.target_mode, colordim=opt.colordim)
test_set = get_test_set(opt.size, target_mode=opt.target_mode, colordim=opt.colordim)
training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True)
testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False)

print('===> Building unet')
unet = UNet(opt.colordim)


criterion = nn.MSELoss()
if cuda:
    unet = unet.cuda()
    criterion = criterion.cuda()

# Pretrained weights are always loaded
pretrained = True
if pretrained:
    # Without --cuda, remap tensors to the CPU so that a checkpoint saved
    # from a GPU can still be loaded; with --cuda the default is kept.
    unet.load_state_dict(torch.load(opt.pretrain_net,
                                    map_location=None if cuda else 'cpu'))

optimizer = optim.SGD(unet.parameters(), lr=opt.lr)
print('===> Training unet')

def train(epoch):
示例#32
0
 def make_discriminator(self, discriminator_args: Dict):
     """Build a UNet from ``discriminator_args`` and return it wrapped in ScoreWrapper."""
     unet = UNet(**discriminator_args)
     return ScoreWrapper(unet)