Example No. 1
def main():
    train_root_dir = '/content/drive/My Drive/DDSM/train/CBIS-DDSM'
    test_root_dir = '/content/drive/My Drive/DDSM/test/CBIS-DDSM'
    path_weights = '/content/drive/My Drive/Cv/weights'
    batch_size = 3
    valid_size = 0.2
    nb_epochs = 20
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # data loaders
    loaders = dataloaders(train_root_dir, combined_transform, batch_size, valid_size)

    model = UNet(in_channels=3, out_channels=1)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.3)

    model = train(model, optimizer, exp_lr_scheduler, loaders, nb_epochs, device, path_weights)
    # from torchsummary import summary
    #
    # summary(model, input_size=(3, 224, 224))

    # test_transform = transforms.Compose([
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # ])
    test_loader = DataLoader(
        MassSegmentationDataset(test_root_dir, combined_transform),
        batch_size=batch_size,
        num_workers=0
    )

    test(model, test_loader, device)
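The dataloaders helper called above is not part of this listing; a minimal sketch of what such a helper might look like, assuming it splits the single training set into 'train' and 'valid' loaders with SubsetRandomSampler (the dict keys and split logic are assumptions):

import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

def dataloaders(root_dir, transform, batch_size, valid_size):
    # Build one dataset and carve it into train/validation loaders.
    dataset = MassSegmentationDataset(root_dir, transform)
    indices = np.random.permutation(len(dataset))
    split = int(valid_size * len(dataset))
    return {
        'train': DataLoader(dataset, batch_size=batch_size,
                            sampler=SubsetRandomSampler(indices[split:])),
        'valid': DataLoader(dataset, batch_size=batch_size,
                            sampler=SubsetRandomSampler(indices[:split])),
    }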
Example No. 2
def main():
    args = parser.parse_args()
    
    dataset = SyntheticCellDataset(args.img_dir, args.mask_dir)
    
    indices = torch.randperm(len(dataset)).tolist()
    sr = int(args.split_ratio * len(dataset))
    train_set = torch.utils.data.Subset(dataset, indices[:-sr])
    val_set = torch.utils.data.Subset(dataset, indices[-sr:])
    
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    
    device = torch.device("cpu" if not args.use_cuda else "cuda:0")
    
    model = UNet()
    model.to(device)
    
    dsc_loss = DiceLoss()
    
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    
    val_overall = 1000
    for epoch in range(args.N_epoch):
        model, train_loss, optimizer = train(model, train_loader, device, optimizer)
        val_loss = validate(model, val_loader, device)
        
        if val_loss < val_overall:
            save_checkpoint(args.model_save_dir + '/epoch_'+str(epoch+1), model, train_loss, val_loss, epoch)
            val_overall = val_loss
            
        print('[{}/{}] train loss :{} val loss : {}'.format(epoch+1, args.N_epoch, train_loss, val_loss))
    print('Training completed')
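Example No. 2 depends on a custom DiceLoss that is not shown; a minimal soft-Dice implementation for binary masks is sketched below (this exact formulation is an assumption, not the project's code):

import torch
import torch.nn as nn

class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation; `smooth` guards against empty masks."""
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = torch.sigmoid(pred).reshape(pred.size(0), -1)
        target = target.reshape(target.size(0), -1)
        intersection = (pred * target).sum(dim=1)
        dice = (2.0 * intersection + self.smooth) / (
            pred.sum(dim=1) + target.sum(dim=1) + self.smooth)
        return 1.0 - dice.mean()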
Example No. 3
def main(conf):
    device = "cuda:0" if torch.cuda.is_available() else 'cpu'
    beta_schedule = "linear"
    beta_start = 1e-4
    beta_end = 2e-2
    n_timestep = 1000

    conf.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    train_set = MultiResolutionDataset(
        conf.dataset.path, transform, conf.dataset.resolution
    )
    train_sampler = dist.data_sampler(
        train_set, shuffle=True, distributed=conf.distributed
    )
    train_loader = conf.training.dataloader.make(train_set, sampler=train_sampler)

    model = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    model = model.to(device)
    ema = UNet(
        conf.model.in_channel,
        conf.model.channel,
        channel_multiplier=conf.model.channel_multiplier,
        n_res_blocks=conf.model.n_res_blocks,
        attn_strides=conf.model.attn_strides,
        dropout=conf.model.dropout,
        fold=conf.model.fold,
    )
    ema = ema.to(device)

    if conf.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = conf.training.optimizer.make(model.parameters())
    scheduler = conf.training.scheduler.make(optimizer)

    betas = make_beta_schedule(beta_schedule, beta_start, beta_end, n_timestep)
    diffusion = GaussianDiffusion(betas).to(device)

    train(conf, train_loader, model, ema, diffusion, optimizer, scheduler, device)
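make_beta_schedule is defined elsewhere in the diffusion code; for the "linear" schedule used here it is typically just an evenly spaced ramp from beta_start to beta_end over n_timestep steps. A sketch under that assumption:

import torch

def make_beta_schedule(schedule, beta_start, beta_end, n_timestep):
    # Only the linear case is exercised above; other schedules are omitted.
    if schedule == "linear":
        return torch.linspace(beta_start, beta_end, n_timestep, dtype=torch.float64)
    raise NotImplementedError(f"unknown beta schedule: {schedule}")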
Example No. 4
def train():
    ex = wandb.init(project="PQRST-segmentation")
    ex.config.setdefaults(wandb_config)

    logging.basicConfig(level=logging.INFO,
                        format="%(levelname)s: %(message)s")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = UNet(in_ch=1, out_ch=4)
    net.to(device)

    try:
        train_model(net=net,
                    device=device,
                    batch_size=wandb.config.batch_size,
                    lr=wandb.config.lr,
                    epochs=wandb.config.epochs)
    except KeyboardInterrupt:
        try:
            save = input("save?(y/n)")
            if save == "y":
                torch.save(net.state_dict(), 'net_params.pkl')
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example No. 5
def train():
    # Load the data sets
    train_dataset = NucleusDataset(
        "data",
        train=True,
        transform=Compose([Rescale(256), ToTensor()]),
        target_transform=Compose([Rescale(256), ToTensor()]))

    # Use cuda if available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Set model to GPU/CPU
    if args.from_checkpoint:
        model = UNet.load(args.from_checkpoint)
    else:
        model = UNet()
    model.to(device)

    # Initialize optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # Initialize trainer
    trainer = Trainer(dataset=train_dataset,
                      model=model,
                      optimizer=optimizer,
                      batch_size=args.batch_size,
                      device=device,
                      output_dir=output_dir)

    # Run the training
    trainer.run_train_loop(epochs=args.epochs)
Example No. 6
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--file_paths', default="data/files.txt")
    parser.add_argument('--landmark_paths', default="data/landmarks.txt")
    parser.add_argument('--landmark', type=int, default=0)
    parser.add_argument('--save_path')
    parser.add_argument('--num_epochs', type=int, default=int(1e9))
    parser.add_argument('--log_freq', type=int, default=100)
    parser.add_argument('--separator', default=",")
    parser.add_argument('--batch_size', type=int, default=8)
    args = parser.parse_args()

    file_paths = args.file_paths
    landmark_paths = args.landmark_paths
    landmark_wanted = args.landmark
    num_epochs = args.num_epochs
    log_freq = args.log_freq
    save_path = args.save_path

    x, y = get_data(file_paths,
                    landmark_paths,
                    landmark_wanted,
                    separator=args.separator)
    print(f"Got {len(x)} images with {len(y)} landmarks")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("device", device)

    dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    unet = UNet(in_dim=1, out_dim=6, num_filters=4)
    criterion = torch.nn.CrossEntropyLoss(weight=get_weigths(y))
    optimizer = optim.SGD(unet.parameters(), lr=0.001, momentum=0.9)

    unet.to(device)

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader):
            inputs, labels = data
            # move the batch to the same device as the model
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = unet(inputs)
            loss = criterion(outputs, labels.long())
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"[{epoch+1}/{num_epochs}] loss: {running_loss}")
        if epoch % log_freq == log_freq - 1:
            if save_path is not None:
                torch.save(unet.state_dict(),
                           os.path.join(save_path, f"unet-{epoch}.pt"))
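get_weigths (spelled as in the source) is another helper that is not shown; since the model has out_dim=6 and the weights feed CrossEntropyLoss, a plausible sketch is inverse-frequency class weights (the behaviour below is an assumption):

import torch

def get_weigths(y):
    # Inverse-frequency weights so under-represented landmark classes
    # contribute more to the cross-entropy loss (6 classes assumed, per out_dim=6).
    labels = torch.as_tensor(y).long().flatten()
    counts = torch.bincount(labels, minlength=6).float()
    weights = 1.0 / (counts + 1e-8)
    return weights / weights.sum()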
Example No. 7
def main():
    args = parser.parse_args()
    device = torch.device("cpu" if not args.use_cuda else "cuda:0")

    model = UNet()
    load_checkpoint(args.weight, model, device='cpu')
    model.to(device)

    img = Image.open(img_file)  # img_file is assumed to be defined elsewhere (e.g. a CLI argument)

    resize = transforms.Resize(size=(576, 576))
    im_r = TF.to_tensor(resize(img))
    im_r = im_r.unsqueeze(0)

    with torch.no_grad():
        pred = model(im_r.to(device))

    pred_mask = pred.detach().cpu().numpy().squeeze()
Example No. 8
def main():
    """
    Training.
    """
    global start_epoch, epoch, checkpoint

    # Initialize model or load checkpoint
    if checkpoint is None:
        model = UNet(in_channels, out_channels)
        # Initialize the optimizer
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad,
                                                   model.parameters()),
                                     lr=lr)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    # Move to default device
    model = model.to(device)
    criterion = nn.L1Loss().to(device)

    # Custom dataloaders
    train_loader = torch.utils.data.DataLoader(TripletDataset(
        train_folder, crop_size, scale),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=workers,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(TripletDataset(
        test_folder, crop_size, scale),
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=workers,
                                              pin_memory=True)

    # Total number of epochs to train for
    epochs = int(iterations // len(train_loader) + 1)

    # Epochs
    for epoch in range(start_epoch, epochs):
        # One epoch's training
        train(train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch,
              epochs=epochs)
        test(test_loader=test_loader, model=model, criterion=criterion)

        # Save checkpoint
        torch.save({
            'epoch': epoch,
            'model': model,
            'optimizer': optimizer
        }, f'checkpoints/checkpoint_unet_{epoch}.pth.tar')
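This script reads several module-level globals (checkpoint, start_epoch, device, lr, data folders, and so on). The values below are placeholders to make the snippet self-contained; none of them come from the original source:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = None                  # or a path to a saved .pth.tar to resume from
start_epoch = 0
in_channels, out_channels = 3, 3   # assumed channel counts
lr = 1e-4
batch_size = 16
workers = 4
iterations = 1_000_000             # total training iterations, converted to epochs above
crop_size, scale = 96, 2
train_folder, test_folder = 'data/train', 'data/test'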
Example No. 9
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', metavar='bs', type=int, default=2)
    parser.add_argument('--path', type=str, default='../../data')
    parser.add_argument('--results', type=str, default='../../results/model')
    parser.add_argument('--nw', type=int, default=0)
    parser.add_argument('--max_images', type=int, default=None)
    parser.add_argument('--val_size', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--lr_decay', type=float, default=0.99997)
    parser.add_argument('--kernel_lvl', type=float, default=1)
    parser.add_argument('--noise_lvl', type=float, default=1)
    # store_true flags: argparse's type=bool would treat any non-empty string as True
    parser.add_argument('--motion_blur', action='store_true')
    parser.add_argument('--homo_align', action='store_true')
    parser.add_argument('--resume', action='store_true')

    args = parser.parse_args()

    print()
    print(args)
    print()

    if not os.path.isdir(args.results): os.makedirs(args.results)

    PATH = args.results
    if not args.resume:
        f = open(PATH + "/param.txt", "a+")
        f.write(str(args))
        f.close()

    writer = SummaryWriter(PATH + '/runs')

    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else "cpu")

    # Parameters
    params = {'batch_size': args.bs, 'shuffle': True, 'num_workers': args.nw}

    # Generators
    print('Initializing training set')
    training_set = Dataset(args.path + '/train/', args.max_images,
                           args.kernel_lvl, args.noise_lvl, args.motion_blur,
                           args.homo_align)
    training_generator = data.DataLoader(training_set, **params)

    print('Initializing validation set')
    validation_set = Dataset(args.path + '/test/', args.val_size,
                             args.kernel_lvl, args.noise_lvl, args.motion_blur,
                             args.homo_align)

    validation_generator = data.DataLoader(validation_set, **params)

    # Model
    model = UNet(in_channel=3, out_channel=3)
    if args.resume:
        models_path = get_newest_model(PATH)
        print('loading model from ', models_path)
        model.load_state_dict(torch.load(models_path))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    model.to(device)

    # Loss + optimizer
    criterion = BurstLoss()
    optimizer = RAdam(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=8 // args.bs, gamma=args.lr_decay)
    if args.resume:
        n_iter = np.loadtxt(PATH + '/train.txt', delimiter=',')[:, 0][-1]
    else:
        n_iter = 0

    # Loop over epochs
    for epoch in range(args.epochs):
        train_loss = 0.0

        # Training
        model.train()
        for i, (X_batch, y_labels) in enumerate(training_generator):
            # Alter the burst length for each mini batch

            burst_length = np.random.randint(2, 9)
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch, y_labels = X_batch.to(device).type(
                torch.float), y_labels.to(device).type(torch.float)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            pred = model(X_batch)
            loss = criterion(pred, y_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.detach().cpu().numpy()
            writer.add_scalar('training_loss', loss.item(), n_iter)

            if i % 100 == 0 and i > 0:
                loss_printable = str(np.round(train_loss, 2))

                f = open(PATH + "/train.txt", "a+")
                f.write(str(n_iter) + "," + loss_printable + "\n")
                f.close()

                print("training loss ", loss_printable)

                train_loss = 0.0

            if i % 1000 == 0:
                if torch.cuda.device_count() > 1:
                    torch.save(
                        model.module.state_dict(),
                        os.path.join(PATH,
                                     'model_' + str(int(n_iter)) + '.pt'))
                else:
                    torch.save(
                        model.state_dict(),
                        os.path.join(PATH,
                                     'model_' + str(int(n_iter)) + '.pt'))

            if i % 1000 == 0:
                # Validation
                val_loss = 0.0
                with torch.set_grad_enabled(False):
                    model.eval()
                    for v, (X_batch,
                            y_labels) in enumerate(validation_generator):
                        # Alter the burst length for each mini batch

                        burst_length = np.random.randint(2, 9)
                        X_batch = X_batch[:, :burst_length, :, :, :]

                        # Transfer to GPU
                        X_batch, y_labels = X_batch.to(device).type(
                            torch.float), y_labels.to(device).type(torch.float)

                        # forward + backward + optimize
                        pred = model(X_batch)
                        loss = criterion(pred, y_labels)

                        val_loss += loss.detach().cpu().numpy()

                        if v < 5:
                            im = make_im(pred, X_batch, y_labels)
                            writer.add_image('image_' + str(v), im, n_iter)

                    writer.add_scalar('validation_loss', val_loss, n_iter)

                    loss_printable = str(np.round(val_loss, 2))
                    print('validation loss ', loss_printable)

                    f = open(PATH + "/eval.txt", "a+")
                    f.write(str(n_iter) + "," + loss_printable + "\n")
                    f.close()

            n_iter += args.bs
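get_newest_model is used here (and again in a later evaluation script) to resume from the latest checkpoint; given that checkpoints are written as model_<iteration>.pt, a reasonable sketch is:

import os
from glob import glob

def get_newest_model(path):
    # Return the checkpoint with the highest iteration number in its filename.
    candidates = glob(os.path.join(path, 'model_*.pt'))
    return max(candidates,
               key=lambda p: int(os.path.splitext(os.path.basename(p))[0].split('_')[-1]))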
Example No. 10
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = UNet(n_channels=1, n_classes=1, bilinear=False, scale=1)
    logging.info(
        f'Network:\n'
        f'\t{net.n_channels} input channels\n'
        f'\t{net.n_classes} output channels (classes)\n'
        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
Example No. 11
import argparse
import numpy as np
import torch
import torch.nn as nn
import cv2
from model import UNet
from glob import glob
import os

parser = argparse.ArgumentParser()
parser.add_argument('--path', default='./checkpoint.pth', help="path to the saved checkpoint of model")
args = parser.parse_args()

filenames = glob('./data/test/*')
filenames.sort()

model = UNet(n_channels=3, bilinear=True)
model.load_state_dict(torch.load(args.path))
model.to('cuda')

with torch.no_grad():
    for i, filename in enumerate(filenames):
        test = cv2.imread(filename)/255.0        
        test = np.expand_dims(test.transpose([2,0,1]), axis=0)
        test = torch.from_numpy(test).to(device="cuda", dtype=torch.float32)

        out = model(test)

        out = out.to(device="cpu").numpy().squeeze()
        out = np.clip(out*255.0, 0, 255)

        path = filename.replace('/test/','/results/')[:-4]+'.png'
        # folder = os.path.dirname(path)
        # if not os.path.exists(folder):
Example No. 12
def main(args):
    writer = SummaryWriter(os.path.join('./logs'))
    # torch.backends.cudnn.benchmark = True
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    test_valid = 'validation' if args.valid else 'test'
    train_data = BlurDataset(os.path.join(args.dataset_root, 'train'),
                             seq_len=args.sequence_length,
                             tau=args.num_frame_blur,
                             delta=5,
                             transform=train_tf)
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length,
                            tau=args.num_frame_blur,
                            delta=5,
                            transform=train_tf)

    train_loader = DataLoader(train_data,
                              batch_size=args.train_batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    if args.checkpoint:
        store_dict = torch.load(args.checkpoint)
        try:
            model.load_state_dict(store_dict['state_dict'])
        except KeyError:
            model.load_state_dict(store_dict)

    if args.train_continue:
        store_dict = torch.load(args.checkpoint)
        model.load_state_dict(store_dict['state_dict'])

    else:
        store_dict = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    model.to(device)
    model.train(True)

    # model = nn.DataParallel(model)

    # TODO DEFINE MORE CRITERIA
    # input(True if device == torch.device('cuda:0') else False)
    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
        # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=True,
        #                              use_gpu=True if device == torch.device('cuda:0') else False)
    }

    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}

    # Define optimizers
    # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,weight_decay=5e-4)
    optimizer = optim.Adam(model.parameters(), lr=args.init_learning_rate)

    # Define lr scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    # best_acc = 0.0
    # start = time.time()
    cLoss = store_dict['loss']
    valLoss = store_dict['valLoss']
    valPSNR = store_dict['valPSNR']
    checkpoint_counter = 0

    loss_tracker = {}
    loss_tracker_test = {}

    psnr_old = 0.0
    dssim_old = 0.0

    for epoch in range(1, 10 *
                       args.epochs):  # loop over the dataset multiple times

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        running_loss = 0

        # Increment scheduler count
        scheduler.step()

        tqdm_loader = tqdm(range(len(train_loader)), ncols=150)

        loss = 0.0
        psnr_ = 0.0
        dssim_ = 0.0

        loss_tracker = {}
        for loss_fn in criterion.keys():
            loss_tracker[loss_fn] = 0.0

        # Train
        model.train(True)
        total_steps = 0.01
        total_steps_test = 0.01
        '''for train_idx, data in enumerate(train_loader, 1):
            loss = 0.0
            blur_data, sharpe_data = data
            #import pdb; pdb.set_trace()
            # input(sharpe_data.shape)
            #import pdb; pdb.set_trace()
            interp_idx = int(math.ceil((args.num_frame_blur/2) - 0.49))
            #input(interp_idx)
            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]
            else:
                #print('\nBoth\n')
                sharpe_data = sharpe_data

            #print(sharpe_data.shape)
            #input(blur_data.shape)
            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except:
                sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # clear gradient
            optimizer.zero_grad()

            # forward pass
            sharpe_out = model(blur_data)
            # import pdb; pdb.set_trace()
            # input(sharpe_out.shape)

            # compute losses
            # import pdb;
            # pdb.set_trace()
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn in criterion.keys():
                loss_tmp = 0.0

                if loss_fn == 'Perceptual':
                    for bidx in range(B):
                        loss_tmp += criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                      sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                    # loss_tmp /= B
                else:
                    loss_tmp = criterion_w[loss_fn] * \
                               criterion[loss_fn](sharpe_out, sharpe_data)


                # try:
                # import pdb; pdb.set_trace()
                loss += loss_tmp # if
                # except :
                try:
                    loss_tracker[loss_fn] += loss_tmp.item()
                except KeyError:
                    loss_tracker[loss_fn] = loss_tmp.item()

            # Backpropagate
            loss.backward()
            optimizer.step()

            # statistics
            # import pdb; pdb.set_trace()
            sharpe_out = sharpe_out.detach().cpu().numpy()
            sharpe_data = sharpe_data.cpu().numpy()
            for sidx in range(S):
                for bidx in range(B):
                    psnr_ += psnr(sharpe_out[bidx, :, sidx, :, :], sharpe_data[bidx, :, sidx, :, :]) #, peak=1.0)
                    """dssim_ += dssim(np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                                    np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0, 2)
                                    )"""

            """sharpe_out = sharpe_out.reshape(-1,3, sx, sy).detach().cpu().numpy()
            sharpe_data = sharpe_data.reshape(-1, 3, sx, sy).cpu().numpy()
            for idx in range(sharpe_out.shape[0]):
                # import pdb; pdb.set_trace()
                psnr_ += psnr(sharpe_data[idx], sharpe_out[idx])
                dssim_ += dssim(np.swapaxes(sharpe_data[idx], 2, 0), np.swapaxes(sharpe_out[idx], 2, 0))"""

            # psnr_ /= sharpe_out.shape[0]
            # dssim_ /= sharpe_out.shape[0]
            running_loss += loss.item()
            loss_str = ''
            total_steps += B*S
            for key in loss_tracker.keys():
               loss_str += ' {0} : {1:6.4f} '.format(key, 1.0*loss_tracker[key] / total_steps)

            # set display info
            if train_idx % 5 == 0:
                tqdm_loader.set_description(('\r[Training] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '.format
                                    (epoch, running_loss / total_steps,
                                     psnr_ / total_steps,
                                     dssim_ / total_steps) + loss_str
                                    ))

                tqdm_loader.update(5)
        tqdm_loader.close()'''

        # Validation
        running_loss_test = 0.0
        psnr_test = 0.0
        dssim_test = 0.0
        # print('len', len(test_loader))
        tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)
        # import pdb; pdb.set_trace()

        loss_tracker_test = {}
        for loss_fn in criterion.keys():
            loss_tracker_test[loss_fn] = 0.0

        with torch.no_grad():
            model.eval()
            total_steps_test = 0.0

            for test_idx, data in enumerate(test_loader, 1):
                loss = 0.0
                blur_data, sharpe_data = data
                interp_idx = int(math.ceil((args.num_frame_blur / 2) - 0.49))
                # input(interp_idx)
                if args.decode_mode == 'interp':
                    sharpe_data = sharpe_data[:, :, 1::2, :, :]
                elif args.decode_mode == 'deblur':
                    sharpe_data = sharpe_data[:, :, 0::2, :, :]
                else:
                    # print('\nBoth\n')
                    sharpe_data = sharpe_data

                # print(sharpe_data.shape)
                # input(blur_data.shape)
                blur_data = blur_data.to(device)[:, :, :, :352, :].permute(
                    0, 1, 2, 4, 3)
                try:
                    sharpe_data = sharpe_data.squeeze().to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
                except:
                    sharpe_data = sharpe_data.squeeze(3).to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

                # clear gradient
                optimizer.zero_grad()

                # forward pass
                sharpe_out = model(blur_data)
                # import pdb; pdb.set_trace()
                # input(sharpe_out.shape)

                # compute losses
                sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
                B, C, S, Fx, Fy = sharpe_out.shape
                for loss_fn in criterion.keys():
                    loss_tmp = 0.0
                    if loss_fn == 'Perceptual':
                        for bidx in range(B):
                            loss_tmp += criterion_w[loss_fn] * \
                                        criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                           sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                        # loss_tmp /= B
                    else:
                        loss_tmp = criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out, sharpe_data)
                    loss += loss_tmp
                    try:
                        loss_tracker_test[loss_fn] += loss_tmp.item()
                    except KeyError:
                        loss_tracker_test[loss_fn] = loss_tmp.item()

                if ((test_idx % args.progress_iter) == args.progress_iter - 1):
                    itr = test_idx + epoch * len(test_loader)
                    # itr_train
                    writer.add_scalars(
                        'Loss', {
                            'trainLoss': running_loss / total_steps,
                            'validationLoss':
                            running_loss_test / total_steps_test
                        }, itr)
                    writer.add_scalar('Train PSNR', psnr_ / total_steps, itr)
                    writer.add_scalar('Test PSNR',
                                      psnr_test / total_steps_test, itr)
                    # import pdb; pdb.set_trace()
                    # writer.add_image('Validation', sharpe_out.permute(0, 2, 3, 1), itr)

                # statistics
                sharpe_out = sharpe_out.detach().cpu().numpy()
                sharpe_data = sharpe_data.cpu().numpy()
                for sidx in range(S):
                    for bidx in range(B):
                        psnr_test += psnr(
                            sharpe_out[bidx, :, sidx, :, :],
                            sharpe_data[bidx, :, sidx, :, :])  #, peak=1.0)
                        dssim_test += dssim(
                            np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                            np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0,
                                        2))  #,range=1.0  )

                running_loss_test += loss.item()
                total_steps_test += B * S
                loss_str = ''
                for key in loss_tracker.keys():
                    loss_str += ' {0} : {1:6.4f} '.format(
                        key, 1.0 * loss_tracker_test[key] / total_steps_test)

                # set display info

                tqdm_loader_test.set_description((
                    '\r[Test    ] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '
                    .format(epoch, running_loss_test / total_steps_test,
                            psnr_test / total_steps_test,
                            dssim_test / total_steps_test) + loss_str))
                tqdm_loader_test.update(1)
            tqdm_loader_test.close()

        # save model
        if psnr_old < (psnr_test / total_steps_test):
            if epoch != 1:
                os.remove(
                    os.path.join(
                        args.checkpoint_dir,
                        'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                            epoch_old,
                            str(round(psnr_old, 4)).replace('.', 'pt'),
                            str(round(dssim_old, 4)).replace('.', 'pt'))))
            epoch_old = epoch
            psnr_old = psnr_test / total_steps_test
            dssim_old = dssim_test / total_steps_test

            checkpoint_dict = {
                'epoch': epoch_old,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_psnr': psnr_ / total_steps,
                'train_dssim': dssim_ / total_steps,
                'train_mse': loss_tracker['MSE'] / total_steps,
                'train_l1': loss_tracker['L1'] / total_steps,
                # 'train_percp': loss_tracker['Perceptual'] / total_steps,
                'test_psnr': psnr_old,
                'test_dssim': dssim_old,
                'test_mse': loss_tracker_test['MSE'] / total_steps_test,
                'test_l1': loss_tracker_test['L1'] / total_steps_test,
                # 'test_percp': loss_tracker_test['Perceptual'] / total_steps_test,
            }

            torch.save(
                checkpoint_dict,
                os.path.join(
                    args.checkpoint_dir,
                    'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                        epoch_old,
                        str(round(psnr_old, 4)).replace('.', 'pt'),
                        str(round(dssim_old, 4)).replace('.', 'pt'))))

        # if epoch % args.checkpoint_epoch == 0:
        #    torch.save(model.state_dict(),args.checkpoint_dir + str(int(epoch/100))+".ckpt")

    return None
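psnr (and dssim) come from the project's utilities. For reference, a minimal PSNR on numpy arrays consistent with how it is called above, assuming images normalized to a peak value of 1.0:

import numpy as np

def psnr(img1, img2, peak=1.0):
    # Peak signal-to-noise ratio in dB between two arrays of the same shape.
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(peak ** 2 / mse)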
Example No. 13
def train(args):
    dataset = open("dataset.csv", "r").readlines()
    train_set = dataset[:600]
    val_set = dataset[600:]
    root_dir = "data/Lung_Segmentation/"

    train_data = LungSegmentationDataGen(train_set, root_dir, args)
    val_data = LungSegmentationDataGen(val_set, root_dir, args)

    train_dataloader = DataLoader(train_data,
                                  batch_size=5,
                                  shuffle=True,
                                  num_workers=4)

    val_dataloader = DataLoader(val_data,
                                batch_size=5,
                                shuffle=True,
                                num_workers=4)

    dataloaders = {"train": train_dataloader, "val": val_dataloader}

    dataset_sizes = {"train": len(train_set), "val": len(val_set)}

    print("dataset_sizes: {}".format(dataset_sizes))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = UNet(in_channels=1)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters())

    loss_train = []
    loss_valid = []

    current_mean_dsc = 0.0
    best_validation_dsc = 0.0

    epochs = args.epochs
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
        dice_score_list = []
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # Iterate over data.
            for i, data in enumerate(dataloaders[phase]):
                inputs, y_true = data
                inputs = inputs.to(device)
                y_true = y_true.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    # forward pass with batch input
                    y_pred = model(inputs)

                    loss = dice_loss(y_true, y_pred)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        # print("step: {}, train_loss: {}".format(i, loss))
                        loss_train.append(loss.item())

                        # calculate the gradients based on loss
                        loss.backward()

                        # update the weights
                        optimizer.step()

                    if phase == "val":
                        loss_valid.append(loss.item())
                        dsc = dice_score(y_true, y_pred)
                        print("step: {}, val_loss: {}, val dice_score: {}".
                              format(i, loss, dsc))
                        dice_score_list.append(dsc.detach().cpu().numpy())

                if phase == "train" and (i + 1) % 10 == 0:
                    print("step:{}, train_loss: {}".format(
                        i + 1, np.mean(loss_train)))
                    loss_train = []
            if phase == "val":
                print("mean val_loss: {}".format(np.mean(loss_valid)))
                loss_valid = []
                current_mean_dsc = np.mean(dice_score_list)
                print("validation set dice_score: {}".format(current_mean_dsc))
                if current_mean_dsc > best_validation_dsc:
                    best_validation_dsc = current_mean_dsc
                    print("best dice_score on val set: {}".format(
                        best_validation_dsc))
                    model_name = "unet_{0:.2f}.pt".format(best_validation_dsc)
                    torch.save(model.state_dict(),
                               os.path.join(args.weights, model_name))
Example No. 14
class Trainer(object):
    """Trainer for training and testing the model"""
    def __init__(self, data_loader, config):
        """Initialize configurations"""

        # model configuration
        self.in_dim = config.in_dim
        self.out_dim = config.out_dim
        self.num_filters = config.num_filters
        self.patch_size = config.patch_size

        # training configuration
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.weight_decay = config.weight_decay
        self.resume_iters = config.resume_iters
        self.mode = config.mode

        # miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:{}'.format(config.device_id) \
                                   if self.use_cuda else 'cpu')

        # training result configuration
        self.log_dir = config.log_dir
        self.log_step = config.log_step
        self.model_save_dir = config.model_save_dir
        self.model_save_step = config.model_save_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # data loader
        if self.mode == 'train' or self.mode == 'test':
            self.data_loader = data_loader
        else:
            self.train_data_loader, self.test_data_loader = data_loader

    def build_model(self):
        """Create a model"""
        self.model = UNet(self.in_dim, self.out_dim, self.num_filters)
        self.model = self.model.float()
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          self.lr, [self.beta1, self.beta2],
                                          weight_decay=self.weight_decay)
        self.print_network(self.model, 'unet')
        self.model.to(self.device)

    def _load(self, checkpoint_path):
        if self.use_cuda:
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(checkpoint_path,
                                    map_location=lambda storage, loc: storage)
        return checkpoint

    def restore_model(self, resume_iters):
        """Restore the trained model"""

        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        model_path = os.path.join(self.model_save_dir,
                                  '{}-unet'.format(resume_iters) + '.ckpt')
        checkpoint = self._load(model_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def print_network(self, model, name):
        """Print out the network information"""

        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        #print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def print_optimizer(self, opt, name):
        """Print out optimizer information"""

        print(opt)
        print(name)

    def build_tensorboard(self):
        """Build tensorboard for visualization"""

        from logger import Logger
        self.logger = Logger(self.log_dir)

    def reset_grad(self):
        """Reset the gradient buffers."""

        self.optimizer.zero_grad()

    def train(self):
        """Train model"""
        if self.mode != 'train_test':
            data_loader = self.data_loader
        else:
            data_loader = self.train_data_loader

        print("current dataset size: ", len(data_loader))
        data_iter = iter(data_loader)

        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        # start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print('Resuming ...')
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)
            self.print_optimizer(self.optimizer, 'optimizer')

        # print learning rate information
        lr = self.lr
        print('Current learning rates, g_lr: {}.'.format(lr))

        # start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # fetch batch data
            try:
                in_data, label = next(data_iter)
            except:
                data_iter = iter(data_loader)
                in_data, label, _, _, _ = next(data_iter)

            in_data = in_data.float().to(self.device)
            label = label.to(self.device)

            # train the model
            self.model = self.model.train()
            y_out = self.model(in_data)
            loss = nn.BCEWithLogitsLoss()
            output = loss(y_out, label)
            self.reset_grad()
            output.backward()
            self.optimizer.step()

            # logging
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                log += ", {}: {:.4f}".format("loss", output.mean().item())
                print(log)

                if self.use_tensorboard:
                    self.logger.scalar_summary("loss",
                                               output.mean().item(), i + 1)

            # save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                path = os.path.join(self.model_save_dir,
                                    '{}-unet'.format(i + 1) + '.ckpt')
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                    }, path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

    def test(self):
        """Test model"""

        if self.mode != 'train_test':
            data_loader = self.data_loader
        else:
            data_loader = self.test_data_loader
        print("current dataset size: ", len(data_loader))
        data_iter = iter(data_loader)

        # start testing on trained model
        if self.resume_iters and self.mode != 'train_test':
            print('Resuming ...')
            self.restore_model(self.resume_iters)

        # start testing.
        result, trace = np.zeros((78, 110, 24)), np.zeros((78, 110, 24))
        print('Start testing...')
        correct, total, bcorrect = 0, 0, 0
        while (True):

            # fetch batch data
            try:
                data_in, label, i, j, k = next(data_iter)
            except:
                break

            data_in = data_in.float().to(self.device)
            label = label.float().to(self.device)

            # test the model
            self.model = self.model.eval()
            y_hat = self.model(data_in)
            m = nn.Sigmoid()
            y_hat = m(y_hat)
            y_hat = y_hat.squeeze().detach().cpu().numpy()

            label = label.cpu().numpy().astype(int)
            y_hat_th = (y_hat > 0.2)
            label = (label > 0.5)
            test = (label == y_hat_th)
            correct += np.sum(test)
            btest = (label == 0)
            bcorrect += np.sum(btest)
            total += y_hat_th.size

            radius = int(self.patch_size / 2)
            for step in range(self.batch_size):
                x, y, z, pred = i[step], j[step], k[step], np.squeeze(
                    y_hat_th[step, :, :, :])
                result[x - radius:x + radius, y - radius:y + radius,
                       z - radius:z + radius] += pred
                trace[x - radius:x + radius, y - radius:y + radius,
                      z - radius:z + radius] += np.ones(
                          (self.patch_size, self.patch_size, self.patch_size))

        print('Accuracy: %.3f%%' % (correct / total * 100))
        print('Baseline Accuracy: %.3f%%' % (bcorrect / total * 100))

        trace += (trace == 0)
        result = result / trace
        scipy.io.savemat('prediction.mat', {'result': result})

    def train_test(self):
        """Train and test model"""

        self.train()
        self.test()
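build_tensorboard imports a project-local Logger; judging only from the scalar_summary call in train(), a minimal stand-in built on torch.utils.tensorboard could look like this (the class below is an assumption about that interface):

from torch.utils.tensorboard import SummaryWriter

class Logger:
    """Tiny tensorboard wrapper exposing the scalar_summary(tag, value, step) call used above."""
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)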
Example No. 15
def trainUnet(dirP, name, setLen, epochs=20):
    class WPCDEDataset(Dataset):
        def __init__(self, lenG, root_dir, transform=None):

            self.root_dir = root_dir
            self.lenG = lenG
            self.transform = transform

        def __len__(self):

            return self.lenG

        def __getitem__(self, idx):
            img_nameX = self.root_dir + '%dx.jpg' % (idx)
            img_nameY = self.root_dir + '%dy.jpg' % (idx)
            imageX = Image.open(img_nameX).convert('RGB')
            imageY = Image.open(img_nameY).convert('RGB')

            if self.transform:
                imageX = self.transform(imageX)
                imageY = self.transform(imageY)
            return imageX, imageY

    transf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    trainDatas = WPCDEDataset(setLen, dirP + r'\\' + 'train', transform=transf)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 16
    lr = 2e-4
    weight_decay = 0
    start_epoch = 0
    outf = r"..\model"

    unet = UNet(in_channels=3, out_channels=3)
    unet.to(device)
    optimizer = optim.Adam(list(unet.parameters()),
                           lr=lr,
                           weight_decay=weight_decay)

    dataloaderT = torch.utils.data.DataLoader(trainDatas,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              num_workers=int(0))

    # dataloaderV = torch.utils.data.DataLoader(valDatas, batch_size=batch_size,
    #                                          shuffle=True, num_workers=int(0))

    dataSplit = None  # reserved rate of  data

    for epoch in range(start_epoch, start_epoch + epochs):
        unet.train()
        for i, (x, y) in enumerate(dataloaderT):
            if dataSplit is not None:
                if i > len(dataloaderT) * dataSplit:
                    break
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            ypred = unet(x)
            loss = F.mse_loss(y, ypred)
            loss.backward()
            optimizer.step()
            # break
            if (i) % int(len(dataloaderT) / 4) == 0:
                print('[%d/%d][%d/%d]\tLoss: %.4f\t ' %
                      (epoch, start_epoch + epochs, i, len(dataloaderT), loss))
        state = {
            'model': unet.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        torch.save(state,
                   '%s/UnetS%d%sepoch%d.pth' % (outf, setLen, name, epoch))
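The checkpoint written at the end of every epoch stores the model state, optimizer state and epoch index; resuming would therefore look roughly like this (path shortened, same keys as saved above):

state = torch.load('../model/UnetS...epoch....pth')  # path elided; use a file written above
unet.load_state_dict(state['model'])
optimizer.load_state_dict(state['optimizer'])
start_epoch = state['epoch'] + 1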
Example No. 16
    out = model(torch.from_numpy(img)).cpu().detach().numpy()
    return np.squeeze(out, 0)[0]


app = Flask(__name__)
UPLOAD_FOLDER = "/home/atom/projects/data_science_bowl_2018/src/static"
PRED_PATH = "/home/atom/projects/data_science_bowl_2018/src/static/"
DEVICE = "cpu"


@app.route("/", methods=['GET', 'POST'])
def upload_predict():
    if request.method == "POST":
        image_file = request.files['image']
        if image_file:
            image_location = os.path.join(UPLOAD_FOLDER, image_file.filename)
            image_file.save(image_location)
            pred = predict(image_location, MODEL)
            imsave(PRED_PATH + "pred.png", pred)
            return render_template("index.html",
                                   image_loc=image_file.filename,
                                   pred_loc="pred.png")
    return render_template("index.html", prediction=0, image_loc=None)


if __name__ == "__main__":
    MODEL = UNet()
    MODEL.load_state_dict(torch.load(config.MODEL_LOAD_PATH))
    MODEL.to(DEVICE)
    MODEL.eval()
    app.run(port=12000, debug=True)
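To exercise the endpoint, an image can be posted from Python with requests; the field name 'image' and port 12000 match the app above, while the file name is just a placeholder:

import requests

with open('example_nucleus.png', 'rb') as f:   # placeholder input image
    resp = requests.post('http://localhost:12000/', files={'image': f})
print(resp.status_code)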
Example No. 17
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_paths',
                        nargs='+',
                        default=['../../results/model'])
    parser.add_argument('--output_path', type=str, default='.')
    parser.add_argument('--NUMBER_OF_IMAGES', type=int, default=5000)
    parser.add_argument('--NUMBER_OF_PLOTS', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--KERNEL_LVL', type=float, default=3)
    parser.add_argument('--NOISE_LVL', type=float, default=1)
    parser.add_argument('--MOTION_BLUR', type=bool, default=True)
    parser.add_argument('--HOMO_ALIGN', type=bool, default=True)
    parser.add_argument('--model_iter', type=int, default=None)
    args = parser.parse_args()

    print()
    print(args)
    print()

    # Evaluation metric parameters
    SSIM_window_size = 3

    dict_ = {}
    for e, exp_path in enumerate(args.exp_paths):

        if args.model_iter is None:
            model_path = get_newest_model(exp_path)
        else:
            model_path = os.path.join(exp_path, str(args.model_iter))

        model_name = os.path.split(model_path)[1]
        name = str(e) + '_' + model_name.replace('.pt', '')

        dict_[name] = {}
        if not os.path.isdir((os.path.join(args.output_path, name))):
            os.mkdir(os.path.join(args.output_path, name))

        model = UNet(in_channel=3, out_channel=3)

        model.load_state_dict(torch.load(model_path))
        model.eval()

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")

        model = model.to(device)

        # Parameters
        params = {'batch_size': 1, 'shuffle': True, 'num_workers': 0}

        random.seed(42)
        np.random.seed(42)
        torch.manual_seed(42)

        # Generators
        data_set = Dataset('../../data/test/',
                           max_images=args.NUMBER_OF_IMAGES,
                           kernel_lvl=args.KERNEL_LVL,
                           noise_lvl=args.NOISE_LVL,
                           motion_blur_boolean=args.MOTION_BLUR,
                           homo_align=args.HOMO_ALIGN)
        data_gen = data.DataLoader(data_set, **params)

        # evaluation
        evaluationData = {}

        for i, (X_batch, y_labels) in enumerate(data_gen):
            # Alter the burst length for each mini batch

            burst_length = np.random.randint(
                2,
                9,
            )
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch, y_labels = X_batch.to(device).type(
                torch.float), y_labels.to(device).type(torch.float)

            with torch.set_grad_enabled(False):
                model.eval()
                pred_batch = model(X_batch)

            evaluationData[str(i)] = {}
            for j in range(params['batch_size']):
                evaluationData[str(i)][str(j)] = {}

                y_label = y_labels[j, :, :, :].detach().cpu().numpy().astype(
                    int)
                pred = pred_batch[j, :, :, :].detach().cpu().numpy().astype(
                    int)

                y_label = np.transpose(y_label, (1, 2, 0))
                pred = np.transpose(pred, (1, 2, 0))
                pred = np.clip(pred, 0, 255)

                if i < args.NUMBER_OF_PLOTS and j == 0:
                    plt.figure(figsize=(20, 5))
                    plt.subplot(1, 2 + len(X_batch[j, :, :, :, :]), 1)
                    plt.imshow(y_label)
                    plt.axis('off')
                    plt.axis('off')
                    plt.title('GT')

                    plt.subplot(1, 2 + len(X_batch[j, :, :, :, :]), 2)
                    plt.imshow(pred)
                    plt.axis('off')
                    plt.title('Pred')

                burst_ssim = []
                burst_psnr = []
                for k in range(len(X_batch[j, :, :, :, :])):
                    x = X_batch[j,
                                k, :, :, :].detach().cpu().numpy().astype(int)
                    burst = np.transpose(x, (1, 2, 0))

                    if i < args.NUMBER_OF_PLOTS and j == 0:
                        plt.subplot(1, 2 + len(X_batch[j, :, :, :, :]), 3 + k)
                        plt.imshow(burst)
                        plt.axis('off')
                        plt.title('Burst ' + str(k))

                    burst_ssim.append(
                        ssim(y_label.astype(float),
                             burst.astype(float),
                             multichannel=True,
                             win_size=SSIM_window_size))
                    burst_psnr.append(psnr(y_label, burst))

                SSIM = ssim(pred.astype(float),
                            y_label.astype(float),
                            multichannel=True,
                            win_size=SSIM_window_size)
                PSNR = psnr(pred, y_label)
                if i < args.NUMBER_OF_PLOTS and j == 0:
                    plt.savefig(os.path.join(args.output_path, name,
                                             str(i) + '.png'),
                                bbox_inches='tight',
                                pad_inches=0)
                    plt.cla()
                    plt.clf()
                    plt.close()

                evaluationData[str(i)][str(j)]['SSIM'] = SSIM
                evaluationData[str(i)][str(j)]['PSNR'] = PSNR
                evaluationData[str(i)][str(j)]['length'] = burst_length
                evaluationData[str(i)][str(j)]['SSIM_burst'] = burst_ssim
                evaluationData[str(i)][str(j)]['PSNR_burst'] = burst_psnr

            if i % 500 == 0 and i > 0:
                print(i)

        #######
        # Save Results
        #######

        x_ssim, y_ssim, y_max_ssim = [], [], []
        x_psnr, y_psnr, y_max_psnr = [], [], []

        for i in evaluationData:
            for j in evaluationData[i]:
                x_ssim.append(evaluationData[i][j]['length'])
                y_ssim.append(evaluationData[i][j]['SSIM'])
                y_max_ssim.append(evaluationData[i][j]['SSIM'] -
                                  max(evaluationData[i][j]['SSIM_burst']))

                x_psnr.append(evaluationData[i][j]['length'])
                y_psnr.append(evaluationData[i][j]['PSNR'])
                y_max_psnr.append(evaluationData[i][j]['PSNR'] -
                                  max(evaluationData[i][j]['PSNR_burst']))

        method = [name] * len(x_ssim)
        dict_[name]['ssim'] = pd.DataFrame(
            np.transpose([x_ssim, y_ssim, y_max_ssim, method]),
            columns=['burst_length', 'ssim', 'max_pred_ssim', 'method'])
        dict_[name]['psnr'] = pd.DataFrame(
            np.transpose([x_psnr, y_psnr, y_max_psnr, method]),
            columns=['burst_length', 'psnr', 'max_pred_psnr', 'method'])

        dict_[name]['ssim'].to_csv(
            os.path.join(args.output_path, 'ssim_' + name + '.csv'))
        dict_[name]['psnr'].to_csv(
            os.path.join(args.output_path, 'psnr_' + name + '.csv'))

def generate_sample(model):
    betas = make_beta_schedule('linear', 1e-4, 2e-2, 1000)
    diffusion = GaussianDiffusion(betas).to('cuda')

    imgs = p_sample_loop(diffusion, model, [16, 3, 128, 128], 'cuda', capture_every=10)
    imgs = imgs[1:]

    id = 0
    grid = make_grid(torch.cat([i[id:id + 1] for i in imgs[:-1:4]], 0), nrow=5, normalize=True, range=(-1, 1))

    return grid.detach().mul(255).cpu().type(torch.uint8).permute(1, 2, 0).numpy()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("checkpoint", type=str, help="Path to checkpoint to load")

    args = parser.parse_args()

    ckpt = torch.load(args.checkpoint)
    model = UNet(3, 128, [1, 1, 2, 2, 4, 4], 2, [16], 0, 1)
    model.load_state_dict(ckpt['ema'])
    model = model.to('cuda')

    grid_output = generate_sample(model)

    # save the sampled grid; the output filename here is illustrative, not from the original code
    Image.fromarray(grid_output).save('samples.png')
Exemplo n.º 19
0
class Trainer():

	def __init__(self,config,trainLoader,validLoader):
		
		self.config = config
		self.trainLoader = trainLoader
		self.validLoader = validLoader
		

		self.numTrain = len(self.trainLoader.dataset)
		self.numValid = len(self.validLoader.dataset)
		
		self.saveModelDir = str(self.config.save_model_dir)+"/"
		
		self.bestModel = config.bestModel
		self.useGpu = self.config.use_gpu


		self.net = UNet()


		if(self.config.resume == True):
			print("LOADING SAVED MODEL")
			self.loadCheckpoint()

		else:
			print("INTIALIZING NEW MODEL")

		self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
		self.net = self.net.to(self.device)

	

		self.totalEpochs = config.epochs
		

		self.optimizer = optim.Adam(self.net.parameters(), lr=5e-4)
		self.loss = DiceLoss()

		self.num_params = sum([p.data.nelement() for p in self.net.parameters()])
		
		self.trainPaitence = config.train_paitence
		

		if not self.config.resume:  # self.freezeLayers(6)
			summary(self.net, input_size=(3,256,256))
			print('[*] Number of model parameters: {:,}'.format(self.num_params))
			self.writer = SummaryWriter(self.config.tensorboard_path+"/")

		
		
		

	def train(self):
		bestIOU = 0
		self.counter = 0  # early-stopping counter, reset whenever validation IOU improves

		print("\n[*] Train on {} sample pairs, validate on {} trials".format(
			self.numTrain, self.numValid))
		

		for epoch in range(0,self.totalEpochs):
			print('\nEpoch: {}/{}'.format(epoch+1, self.totalEpochs))
			
			self.trainOneEpoch(epoch)

			validationIOU = self.validationTest(epoch)

			print("VALIDATION IOU: ",validationIOU)

			# check for improvement
			if(validationIOU > bestIOU):
				print("COUNT RESET !!!")
				bestIOU=validationIOU
				self.counter = 0
				self.saveCheckPoint(
				{
					'epoch': epoch + 1,
					'model_state': self.net.state_dict(),
					'optim_state': self.optimizer.state_dict(),
					'best_valid_acc': bestIOU,
				},True)

			else:
				self.counter += 1
				
			
			if self.counter > self.trainPaitence:
				self.saveCheckPoint(
				{
					'epoch': epoch + 1,
					'model_state': self.net.state_dict(),
					'optim_state': self.optimizer.state_dict(),
					'best_valid_acc': validationIOU,
				},False)
				print("[!] No improvement in a while, stopping training...")
				print("BEST VALIDATION IOU: ",bestIOU)

				return None

		
	def trainOneEpoch(self,epoch):
		self.net.train()
		train_loss = 0
		total_IOU = 0
		
		for batch_idx, (images,targets) in enumerate(self.trainLoader):


			images = images.to(self.device)
			targets = targets.to(self.device)

			
	
			self.optimizer.zero_grad()

			outputMaps = self.net(images)
			
			loss = self.loss(outputMaps,targets)
			

			
			loss.backward()
			self.optimizer.step()

			train_loss += loss.item()

			current_IOU = calc_IOU(outputMaps,targets)
			total_IOU += current_IOU
			
			del(images)
			del(targets)

			progress_bar(batch_idx, len(self.trainLoader), 'Loss: %.3f | IOU: %.3f'
		% (train_loss/(batch_idx+1), current_IOU))
		self.writer.add_scalar('Train/Loss', train_loss/(batch_idx+1), epoch)
		self.writer.add_scalar('Train/IOU', total_IOU/(batch_idx+1), epoch)
		
		


	def validationTest(self,epoch):
		self.net.eval()
		validationLoss = []
		total_IOU = []
		with torch.no_grad():
			for batch_idx, (images,targets) in enumerate(self.validLoader):
				
				
				
				images = images.to(self.device)
				targets = targets.to(self.device)


				outputMaps = self.net(images)

				loss = self.loss(outputMaps,targets)


				currentValidationLoss = loss.item()
				validationLoss.append(currentValidationLoss)
				current_IOU = calc_IOU(outputMaps,targets)
				total_IOU.append(current_IOU)

			
				# progress_bar(batch_idx, len(self.validLoader), 'Loss: %.3f | IOU: %.3f' % (currentValidationLoss), current_IOU)


				del(images)
				del(targets)

		meanIOU = np.mean(total_IOU)
		meanValidationLoss = np.mean(validationLoss)
		self.writer.add_scalar('Validation/Loss', meanValidationLoss, epoch)
		self.writer.add_scalar('Validation/IOU', meanIOU, epoch)
		
		print("VALIDATION LOSS: ",meanValidationLoss)
				
		
		return meanIOU



	def test(self,dataLoader):

		self.net.eval()
		testLoss = []
		total_IOU = []

		total_outputs_maps = []
		total_input_images = []
		
		with torch.no_grad():
			for batch_idx, (images,targets) in enumerate(dataLoader):

				images = images.to(self.device)
				targets = targets.to(self.device)


				outputMaps = self.net(images)

				
				loss = self.loss(outputMaps,targets)

				testLoss.append(loss.item())
				current_IOU = calc_IOU(outputMaps,targets)
				
				total_IOU.append(current_IOU)
				
				total_outputs_maps.append(outputMaps.cpu().detach().numpy())


				# total_input_images.append(transforms.ToPILImage()(images))
				
				total_input_images.append(images.cpu().detach().numpy())

				del(images)
				del(targets)
				break

		meanIOU = np.mean(total_IOU)
		meanLoss = np.mean(testLoss)
		print("TEST IOU: ",meanIOU)
		print("TEST LOSS: ",meanLoss)	

		return total_input_images,total_outputs_maps
		

		
	def saveCheckPoint(self,state,isBest):
		filename = "model.pth"
		ckpt_path = os.path.join(self.saveModelDir, filename)
		torch.save(state, ckpt_path)
		
		if isBest:
			filename = "best_model.pth"
			shutil.copyfile(ckpt_path, os.path.join(self.saveModelDir, filename))

	def loadCheckpoint(self):

		print("[*] Loading model from {}".format(self.saveModelDir))
		if(self.bestModel):
			print("LOADING BEST MODEL")

			filename = "best_model.pth"

		else:
			filename = "model.pth"

		ckpt_path = os.path.join(self.saveModelDir, filename)
		print(ckpt_path)
		
		if(self.useGpu==False):
			self.net=torch.load(ckpt_path, map_location=lambda storage, loc: storage)


			

		else:
			print("*"*40+" LOADING MODEL FROM GPU "+"*"*40)
			self.ckpt = torch.load(ckpt_path)
			self.net.load_state_dict(self.ckpt['model_state'])

			self.net.cuda()
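# A hedged usage sketch for the Trainer above. The SimpleNamespace mirrors the config
# attributes the class actually reads (save_model_dir, bestModel, use_gpu, resume,
# epochs, train_paitence, tensorboard_path); the random-tensor loaders are only
# placeholders for the project's real segmentation DataLoaders.
from types import SimpleNamespace
import torch
from torch.utils.data import DataLoader, TensorDataset

config = SimpleNamespace(save_model_dir='./checkpoints', bestModel=False,
                         use_gpu=torch.cuda.is_available(), resume=False,
                         epochs=10, train_paitence=5, tensorboard_path='./runs')

dummy = TensorDataset(torch.rand(8, 3, 256, 256),
                      torch.randint(0, 2, (8, 1, 256, 256)).float())
trainer = Trainer(config, DataLoader(dummy, batch_size=2), DataLoader(dummy, batch_size=2))
trainer.train()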
Exemplo n.º 20
0
File: main.py Project: Kwrleon/Zero
transform = transforms.Compose([transforms.Resize(160),
                                transforms.ToTensor()])

train_name = np.load(os.path.join(path,"train_names.npy"))
valid_name = np.load(os.path.join(path,"valid_names.npy"))

train_set = Dataset(path = path, data_name = train_name, transform = transform)
valid_set = Dataset(path = path, data_name = valid_name, transform = transform)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=v_batch_size, shuffle=False)

# load the model
model = model()
#summary(model, (1,400,640),device='cpu')
model = model.to(device)
#model = torch.nn.DataParallel(model)  # run on multiple GPUs in parallel

# loss function
criterion = torch.nn.BCELoss().to(device)
# define the optimizer
optimizer = torch.optim.SGD(model.parameters(),
                            lr = lr,
                            momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
                            optimizer,
                            milestones = [8,14,19,25,30,35],
                            gamma = 0.1)
print("Optimizer: ", optimizer.__class__.__name__)

# start training
class Instructor:
    ''' Model training and evaluation '''
    def __init__(self, opt):
        self.opt = opt
        if opt.inference:
            self.testset = TestImageDataset(fdir=opt.impaths['test'],
                                            imsize=opt.imsize)
        else:
            self.trainset = ImageDataset(fdir=opt.impaths['train'],
                                         bdir=opt.impaths['btrain'],
                                         imsize=opt.imsize,
                                         mode='train',
                                         aug_prob=opt.aug_prob,
                                         prefetch=opt.prefetch)
            self.valset = ImageDataset(fdir=opt.impaths['val'],
                                       bdir=opt.impaths['bval'],
                                       imsize=opt.imsize,
                                       mode='val',
                                       aug_prob=opt.aug_prob,
                                       prefetch=opt.prefetch)
        self.model = UNet(n_channels=3,
                          n_classes=1,
                          bilinear=self.opt.use_bilinear)
        if opt.checkpoint:
            self.model.load_state_dict(
                torch.load('./state_dict/{:s}'.format(opt.checkpoint),
                           map_location=self.opt.device))
            print('checkpoint {:s} has been loaded'.format(opt.checkpoint))
        if opt.multi_gpu == 'on':
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.to(opt.device)
        self._print_args()

    def _print_args(self):
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            n_params = torch.prod(torch.tensor(p.shape))
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        self.info = 'n_trainable_params: {0}, n_nontrainable_params: {1}\n'.format(
            n_trainable_params, n_nontrainable_params)
        self.info += 'training arguments:\n' + '\n'.join([
            '>>> {0}: {1}'.format(arg, getattr(self.opt, arg))
            for arg in vars(self.opt)
        ])
        if self.opt.device.type == 'cuda':
            print('cuda memory allocated:',
                  torch.cuda.memory_allocated(opt.device.index))
        print(self.info)

    def _reset_records(self):
        self.records = {
            'best_epoch': 0,
            'best_dice': 0,
            'train_loss': list(),
            'val_loss': list(),
            'val_dice': list(),
            'checkpoints': list()
        }

    def _update_records(self, epoch, train_loss, val_loss, val_dice):
        if val_dice > self.records['best_dice']:
            path = './state_dict/{:s}_dice{:.4f}_temp{:s}.pt'.format(
                self.opt.model_name, val_dice,
                str(time.time())[-6:])
            if self.opt.multi_gpu == 'on':
                torch.save(self.model.module.state_dict(), path)
            else:
                torch.save(self.model.state_dict(), path)
            self.records['best_epoch'] = epoch
            self.records['best_dice'] = val_dice
            self.records['checkpoints'].append(path)
        self.records['train_loss'].append(train_loss)
        self.records['val_loss'].append(val_loss)
        self.records['val_dice'].append(val_dice)

    def _draw_records(self):
        timestamp = str(int(time.time()))
        print('best epoch: {:d}'.format(self.records['best_epoch']))
        print('best train loss: {:.4f}, best val loss: {:.4f}'.format(
            min(self.records['train_loss']), min(self.records['val_loss'])))
        print('best val dice {:.4f}'.format(self.records['best_dice']))
        os.rename(
            self.records['checkpoints'][-1],
            './state_dict/{:s}_dice{:.4f}_save{:s}.pt'.format(
                self.opt.model_name, self.records['best_dice'], timestamp))
        for path in self.records['checkpoints'][0:-1]:
            os.remove(path)
        # Draw figures
        plt.figure()
        trainloss, = plt.plot(self.records['train_loss'])
        valloss, = plt.plot(self.records['val_loss'])
        plt.legend([trainloss, valloss], ['train', 'val'], loc='upper right')
        plt.title('{:s} loss curve'.format(timestamp))
        plt.savefig('./figs/{:s}_loss.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        plt.figure()
        valdice, = plt.plot(self.records['val_dice'])
        plt.title('{:s} dice curve'.format(timestamp))
        plt.savefig('./figs/{:s}_dice.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        # Save report
        report = '\t'.join(
            ['val_dice', 'train_loss', 'val_loss', 'best_epoch', 'timestamp'])
        report += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:d}\t{:s}\n{:s}".format(
            self.records['best_dice'], min(self.records['train_loss']),
            min(self.records['val_loss']), self.records['best_epoch'],
            timestamp, self.info)
        with open('./logs/{:s}_log.txt'.format(timestamp), 'w') as f:
            f.write(report)
        print('report saved:', './logs/{:s}_log.txt'.format(timestamp))

    def _train(self, train_dataloader, criterion, optimizer):
        self.model.train()
        train_loss, n_total, n_batch = 0, 0, len(train_dataloader)
        for i_batch, sample_batched in enumerate(train_dataloader):
            inputs, target = sample_batched[0].to(
                self.opt.device), sample_batched[1].to(self.opt.device)
            predict = self.model(inputs)

            optimizer.zero_grad()
            loss = criterion(predict, target)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(sample_batched)
            n_total += len(sample_batched)

            ratio = int((i_batch + 1) * 50 / n_batch)
            sys.stdout.write("\r[" + ">" * ratio + " " * (50 - ratio) +
                             "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                                      (i_batch + 1) * 100 /
                                                      n_batch))
            sys.stdout.flush()
        print()
        return train_loss / n_total

    def _evaluation(self, val_dataloader, criterion):
        self.model.eval()
        val_loss, val_dice, n_total = 0, 0, 0
        with torch.no_grad():
            for sample_batched in val_dataloader:
                inputs, target = sample_batched[0].to(
                    self.opt.device), sample_batched[1].to(self.opt.device)
                predict = self.model(inputs)
                loss = criterion(predict, target)
                dice = dice_coeff(predict, target)
                val_loss += loss.item() * len(sample_batched)
                val_dice += dice.item() * len(sample_batched)
                n_total += len(sample_batched)
        return val_loss / n_total, val_dice / n_total

    def run(self):
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.Adam(_params,
                                     lr=self.opt.lr,
                                     weight_decay=self.opt.l2reg)
        criterion = BCELoss2d()
        train_dataloader = DataLoader(dataset=self.trainset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=True)
        val_dataloader = DataLoader(dataset=self.valset,
                                    batch_size=self.opt.batch_size,
                                    shuffle=False)
        self._reset_records()
        for epoch in range(self.opt.num_epoch):
            train_loss = self._train(train_dataloader, criterion, optimizer)
            val_loss, val_dice = self._evaluation(val_dataloader, criterion)
            self._update_records(epoch, train_loss, val_loss, val_dice)
            print(
                '{:d}/{:d} > train loss: {:.4f}, val loss: {:.4f}, val dice: {:.4f}'
                .format(epoch + 1, self.opt.num_epoch, train_loss, val_loss,
                        val_dice))
        self._draw_records()

    def inference(self):
        test_dataloader = DataLoader(dataset=self.testset,
                                     batch_size=1,
                                     shuffle=False)
        n_batch = len(test_dataloader)
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(test_dataloader):
                index, inputs = sample_batched[0], sample_batched[1].to(
                    self.opt.device)
                predict = self.model(inputs)
                self.testset.save_img(index.item(), predict, self.opt.use_crf)
                ratio = int((i_batch + 1) * 50 / n_batch)
                sys.stdout.write(
                    "\r[" + ">" * ratio + " " * (50 - ratio) +
                    "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                             (i_batch + 1) * 100 / n_batch))
                sys.stdout.flush()
        print()
Exemplo n.º 22
0
def train(config, dataloader_train, dataloader_test=None):
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    if config.pretrain:
        # 3 classes one for each color channel
        num_classes = 3
        model = UNet(n_channels=3, n_classes=num_classes)
        model.pretrain(True)
        loss_func = F.mse_loss
    else:
        # 3 classes one for each color channel and 2 for segmentation
        num_classes = 5
        model = UNet(n_channels=3, n_classes=num_classes)
        model.pretrain(False)
        # 0.3 and 0.7 were calculated by checking the ratio of pixels in each class
        loss_func = CombinedLoss(
            weight=torch.tensor([0.3, 0.7], device=device))

    # move model to the right device
    model.to(device)

    if config.load_pretrain:
        try:
            pretrained_dict = torch.load(config.pretrain_weight_path,
                                         map_location=device)

            # 1. filter out unnecessary keys
            model_dict = model.state_dict()
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items() if "outc" not in k
            }

            # 2. overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)

            # 3. load the new state dict
            model.load_state_dict(model_dict)
        except Exception as e:
            print("Could not load weights")
            raise e

    # construct an optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    for epoch in range(config.num_epochs):
        train_loss, train_accuracy = train_one_epoch(
            model,
            optimizer,
            loss_func,
            dataloader_train,
            device,
            epoch,
            print_freq=config.print_freq)
        print(
            f"Epoch =  {epoch} \t Train loss = {train_loss} \t Train Accuracy = {train_accuracy}"
        )
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # update the learning rate
        lr_scheduler.step()
        if config.pretrain:
            model_name = config.snapshots_folder + f"/pretrain-{epoch}.pt"
            torch.save(model.state_dict(), model_name)
            continue
        else:
            model_name = config.snapshots_folder + f"/model-{epoch}.pt"
            torch.save(model.state_dict(), model_name)

        val_loss, val_accuracy = evaluate(model,
                                          dataloader_test,
                                          loss_func,
                                          device=device,
                                          print_freq=20,
                                          acc_func=calc_accuracy,
                                          iters=20)
        print(
            f"Epoch =  {epoch} \t Validation loss = {val_loss} \t Validation Accuracy = {val_accuracy}"
        )
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

    plt.plot(train_accuracies, label="train accuracy")
    plt.plot(val_accuracies, label="val accuracy")
    plt.legend()
    plt.savefig(config.snapshots_folder + "/accuracy_plot.jpg")

    plt.plot(train_losses, label="train loss")
    plt.plot(val_losses, label="val loss")
    plt.legend()
    plt.savefig(config.snapshots_folder + "/loss_plot.jpg")
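# Hedged sketch of how class weights such as the 0.3 / 0.7 used above could be derived:
# count foreground vs. background pixels over the training masks and give each class a
# weight proportional to the other class's pixel fraction, so the rarer class weighs more.
# `mask_loader` is an assumed DataLoader yielding (image, binary_mask) pairs; it is not
# part of the original code.
import torch

def estimate_class_weights(mask_loader):
    fg, total = 0, 0
    for _, mask in mask_loader:
        fg += (mask > 0).sum().item()
        total += mask.numel()
    bg = total - fg
    # normalized inverse-frequency weights: [background, foreground]
    return torch.tensor([fg / total, bg / total])

# weight = estimate_class_weights(train_mask_loader)  # e.g. tensor([0.3, 0.7])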
Exemplo n.º 23
0
# image_numpy = image.numpy()
# mask_numpy = mask.numpy()
# print(image_numpy.shape)
# print(mask_numpy.shape)
# image_numpy_transpose = np.transpose(image_numpy, (1, 2, 0))
# mask_numpy_transpose = np.transpose(mask_numpy, (1, 2, 0))
# print(image_numpy_transpose.shape)
# print(mask_numpy_transpose.shape)
# plt.imshow(image_numpy_transpose)
# plt.show()
# plt.imshow(mask_numpy_transpose.squeeze(), cmap='gray')
# plt.show()

unet = UNet(in_channels=Brain_Segmentation_Dataset.in_channels, out_channels=Brain_Segmentation_Dataset.out_channels)
print(unet)
unet.to(device)

dsc_loss = DiceLoss()
best_validation_dsc = 0.0
optimizer = optim.Adam(unet.parameters(), lr=args.lr)

loss_train = []
loss_valid = []

step = 0

for epoch in range(args.epochs):
    for phase in ['train', 'valid']:
        if phase == 'train':
            unet.train()
        else:
            unet.eval()
Exemplo n.º 24
0
def load_model(model_path, device):
    model = UNet(3, 3)
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device(device)))
    model.to(device)
    return model
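# Hedged usage sketch for load_model above: push one RGB image through the loaded
# 3-in / 3-out UNet and save the result. The ToTensor / ToPILImage round-trip and the
# file names are illustrative assumptions, not part of the original project.
import torch
from PIL import Image
from torchvision import transforms

def predict_image(model, image_path, device, out_path='prediction.png'):
    to_tensor = transforms.ToTensor()
    to_image = transforms.ToPILImage()
    x = to_tensor(Image.open(image_path).convert('RGB')).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        y = model(x).clamp(0, 1).squeeze(0).cpu()
    to_image(y).save(out_path)

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = load_model('unet.pth', device)
# predict_image(model, 'input.png', device)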
Exemplo n.º 25
0
def train(cfg_path, device='cuda'):
    if cfg_path is not None:
        cfg.merge_from_file(cfg_path)
    cfg.freeze()

    if not os.path.isdir(cfg.LOG_DIR):
        os.makedirs(cfg.LOG_DIR)
    if not os.path.isdir(cfg.SAVE_DIR):
        os.makedirs(cfg.SAVE_DIR)

    model = UNet(cfg.NUM_CHANNELS, cfg.NUM_CLASSES)
    model.to(device)

    train_data_loader = build_data_loader(cfg, 'train')
    if cfg.VAL:
        val_data_loader = build_data_loader(cfg, 'val')
    else:
        val_data_loader = None

    optimizer = build_optimizer(cfg, model)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    criterion = get_loss_func(cfg)
    writer = SummaryWriter(cfg.LOG_DIR)

    iter_counter = 0
    loss_meter = AverageMeter()
    val_loss_meter = AverageMeter()
    min_val_loss = 1e10

    print('Training Start')
    for epoch in range(cfg.SOLVER.MAX_EPOCH):
        print('Epoch {}/{}'.format(epoch + 1, cfg.SOLVER.MAX_EPOCH))
        if lr_scheduler is not None:
            lr_scheduler.step(epoch)
        for data in train_data_loader:
            iter_counter += 1

            imgs, annots = data
            imgs = imgs.to(device)
            annots = annots.to(device)

            y = model(imgs)
            optimizer.zero_grad()
            loss = criterion(y, annots)
            loss.backward()
            optimizer.step()
            loss_meter.update(loss.item())

            if iter_counter % 10 == 0:
                writer.add_scalars('loss', {'train': loss_meter.avg},
                                   iter_counter)
                loss_meter.reset()
            if lr_scheduler is not None:
                writer.add_scalar('learning rate',
                                  optimizer.param_groups[0]['lr'],
                                  iter_counter)
            save_as_checkpoint(model, optimizer,
                               os.path.join(cfg.SAVE_DIR, 'checkpoint.pth'),
                               epoch, iter_counter)

        # Skip validation when cfg.VAL is False
        if val_data_loader is None:
            continue

        val_loss_meter.reset()
        for data in val_data_loader:
            with torch.no_grad():
                imgs, annots = data
                imgs = imgs.to(device)
                annots = annots.to(device)

                y = model(imgs)
                optimizer.zero_grad()
                loss = criterion(y, annots)
                val_loss_meter.update(loss.item())
        if val_loss_meter.avg < min_val_loss:
            min_val_loss = val_loss_meter.avg
            writer.add_scalars('loss', {'val': val_loss_meter.avg},
                               iter_counter)
            # save model if validation loss is minimum
            torch.save(model.state_dict(),
                       os.path.join(cfg.SAVE_DIR, 'min_val_loss.pth'))
Exemplo n.º 26
0
class Pipeline:
    net = None
    optimizer = None

    def __init__(self, model_architecture: str, device: torch.device):
        assert model_architecture in ['unet', 'mnet2']
        self.model_architecture = model_architecture
        self.device = device

    def create_net(self) -> nn.Module:
        if self.model_architecture == 'unet':
            self.net = UNet(n_channels=1, n_classes=1)
        elif self.model_architecture == 'mnet2':
            self.net = MobileNetV2_UNet()
        else:
            raise ValueError(
                f'model_architecture must be in ["unet", "mnet2"]. '
                f'passed: {self.model_architecture}')

        self.net.to(device=self.device)

        return self.net

    def create_optimizer(self) -> Optimizer:
        """
        It is important to create optimizer only after moving model to appropriate device
        as model's parameters will be different after changing the device.
        """

        # optimizer = optim.SGD(self.net.parameters(), lr=0.0001, momentum=0.9)
        self.optimizer = optim.Adam(self.net.parameters(), lr=1e-3)

        return self.optimizer

    def load_net_from_weights(self, checkpoint_fp: str):
        """load model parameters from checkpoint .pth file"""
        print(f'\nload_net_from_weights()')
        print(f'loading model parameters from "{checkpoint_fp}"')

        self.create_net()

        state_dict = torch.load(checkpoint_fp)
        self.net.load_state_dict(state_dict)

    def train(self,
              train_loader: BaseDataLoader,
              valid_loader: BaseDataLoader,
              n_epochs: int,
              loss_func: nn.Module,
              metrics: List[nn.Module],
              out_dp: str = None,
              max_batches: int = None,
              initial_checkpoint_fp: str = None):
        """
        Train wrapper.

        :param max_batches: maximum number of batches for training and validation to perform sanity check
        :param initial_checkpoint_fp: path to .pth checkpoint for warm start
        """

        out_dp = out_dp or const.RESULTS_DN
        # check if dir is nonempty
        utils.prompt_to_clear_dir_content_if_nonempty(out_dp)
        os.makedirs(out_dp, exist_ok=True)

        print(const.SEPARATOR)
        if initial_checkpoint_fp is not None:
            print('training with WARM START')
            self.load_net_from_weights(initial_checkpoint_fp)
        else:
            print('training with COLD START')
            self.create_net()

        self.create_optimizer()

        # consider providing the same tolerance to ReduceLROnPlateau and Early Stopping
        tolerance = 1e-4

        scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                         mode='min',
                                                         factor=0.2,
                                                         min_lr=1e-6,
                                                         threshold=tolerance,
                                                         patience=2,
                                                         threshold_mode='abs',
                                                         cooldown=0,
                                                         verbose=True)

        history = mu.train_valid(net=self.net,
                                 loss_func=loss_func,
                                 metrics=metrics,
                                 train_loader=train_loader,
                                 valid_loader=valid_loader,
                                 optimizer=self.optimizer,
                                 scheduler=scheduler,
                                 device=self.device,
                                 n_epochs=n_epochs,
                                 es_tolerance=tolerance,
                                 es_patience=15,
                                 out_dp=out_dp,
                                 max_batches=max_batches)

        # store history dict to .pickle file
        print(const.SEPARATOR)
        history_out_fp = f'{out_dp}/train_history_{history["loss_name"]}.pickle'
        print(f'storing train history dict to "{history_out_fp}"')
        with open(history_out_fp, 'wb') as fout:
            pickle.dump(history, fout)

        utils.print_cuda_memory_stats(self.device)

    # def evaluate_model(self):
    #     print(const.SEPARATOR)
    #     print('evaluate_model()')
    #
    #     if self.net is None:
    #         raise ValueError('must call train() or load_model() before evaluating')
    #
    #     hd, hd_avg = mu.get_hd_for_valid_slices(
    #         self.net, self.device, loss_name, self.indices_valid, self.scans_dp, self.masks_dp
    #     )
    #
    #     hd_list = [x[1] for x in hd]
    #     mu.build_hd_boxplot(hd_list, False, loss_name)
    #     mu.visualize_worst_best(self.net, hd, False, self.scans_dp, self.masks_dp, self.device, loss_name)
    #
    #     hd_avg_list = [x[1] for x in hd_avg]
    #     mu.build_hd_boxplot(hd_avg_list, True, loss_name)
    #     mu.visualize_worst_best(self.net, hd_avg, True, self.scans_dp, self.masks_dp, self.device, loss_name)

    def segment_scans(self,
                      checkpoint_fp: str,
                      scans_dp: str,
                      postfix: str,
                      ids: List[str] = None,
                      output_dp: str = None):
        """
        :param checkpoint_fp:   path to .pth file with net's params dict
        :param scans_dp:    path directory with .nii.gz scans.
                            will check that scans do not have any postfixes in their filenames.
        :param postfix:     postfix of segmented filenames
        :param ids:    list of image ids to consider. if None segment all scans under `scans_dp`
        :param output_dp:   path to directory to store results of segmentation
        """
        utils.check_var_to_be_iterable_collection(ids)

        print(const.SEPARATOR)
        print('Pipeline.segment_scans()')

        output_dp = output_dp or const.SEGMENTED_DN
        print(f'will store segmented masks under "{output_dp}"')
        os.makedirs(output_dp, exist_ok=True)

        print(f'postfix: {postfix}')

        self.load_net_from_weights(checkpoint_fp)
        scans_fps = utils.get_nii_gz_filepaths(scans_dp)
        print(f'# of .nii.gz files under "{scans_dp}": {len(scans_fps)}')

        # filter filepaths to scans
        scans_fps_filtered = []
        for fp in scans_fps:
            img_id, img_postfix = utils.parse_image_id_from_filepath(
                fp, get_postfix=True)
            if img_postfix != '' or ids is not None and img_id not in ids:
                continue
            scans_fps_filtered.append(fp)
        print(f'# of scans left after filtering: {len(scans_fps_filtered)}')

        print('\nstarting segmentation...')
        time_start_segmentation = time.time()

        with tqdm.tqdm(total=len(scans_fps_filtered)) as pbar:
            for fp in scans_fps_filtered:
                cur_id = utils.parse_image_id_from_filepath(fp)
                pbar.set_description(cur_id)

                scan_nifti, scan_data = utils.load_nifti(fp)

                # clip intensities as during training
                scan_data_clipped = preprocessing.clip_intensities(scan_data)

                segmented_data = mu.segment_single_scan(
                    scan_data_clipped, self.net, self.device)
                segmented_nifti = utils.change_nifti_data(segmented_data,
                                                          scan_nifti,
                                                          is_scan=False)

                out_fp = os.path.join(output_dp, f'{cur_id}_{postfix}.nii.gz')
                utils.store_nifti_to_file(segmented_nifti, out_fp)

                pbar.update()

        print(
            f'\nsegmentation ended. elapsed time: {utils.get_elapsed_time_str(time_start_segmentation)}'
        )
        utils.print_cuda_memory_stats(self.device)

    def lr_find_and_store(self,
                          loss_func: nn.Module,
                          train_loader: BaseDataLoader,
                          out_dp: str = None):
        """
        LRFinder wrapper.
        """
        out_dp = out_dp or os.path.join(const.RESULTS_DN,
                                        const.LR_FINDER_RESULTS_DN)
        os.makedirs(out_dp, exist_ok=True)
        utils.prompt_to_clear_dir_content_if_nonempty(out_dp)

        self.create_net()
        self.create_optimizer()

        lr_finder = LRFinder(net=self.net,
                             loss_func=loss_func,
                             optimizer=self.optimizer,
                             train_loader=train_loader,
                             device=self.device)
        lr_finder.lr_find()
        lr_finder.store_results(out_dp=out_dp)
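# Hedged usage sketch for the Pipeline above, illustrating the ordering that
# create_optimizer's docstring relies on: the network is created and moved to the
# device first, and only then is the optimizer built over its parameters. The loss,
# metrics and loaders in the commented call are placeholders for the project's own.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipeline = Pipeline(model_architecture='unet', device=device)

net = pipeline.create_net()              # builds the UNet and moves it to `device`
optimizer = pipeline.create_optimizer()  # Adam over the already-moved parameters

# pipeline.train(train_loader=train_loader, valid_loader=valid_loader,
#                n_epochs=50, loss_func=loss_func, metrics=[dice_metric])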
Exemplo n.º 27
0
class Solver:
    def __init__(self,
                 config,
                 train_loader=None,
                 val_loader=None,
                 test_loader=None):
        self.cfg = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.n_gpus = self.cfg.n_gpus

        if self.cfg.mode in ['train', 'test']:
            self.train_loader = train_loader
            self.val_loader = val_loader
        else:
            self.test_loader = test_loader

        # Build model
        self.build_model()
        if self.cfg.resume:
            self.load_pre_model()
        else:
            self.start_epoch = 0

        # Trigger Tensorboard Logger
        if self.cfg.use_tensorboard:
            try:
                from tensorboardX import SummaryWriter
                self.writer = SummaryWriter()
            except ImportError:
                print(
                    '=> There is no module named tensorboardX, tensorboard disabled'
                )
                self.cfg.use_tensorboard = False

    def train_val(self):
        # Build record objs
        self.build_recorder()

        iter_per_epoch = len(
            self.train_loader.dataset) // self.cfg.train_batch_size
        if len(self.train_loader.dataset) % self.cfg.train_batch_size != 0:
            iter_per_epoch += 1

        for epoch in range(self.start_epoch,
                           self.start_epoch + self.cfg.n_epochs):

            self.model.train()

            self.train_time.reset()
            self.train_loss.reset()
            self.train_cls_acc.reset()
            self.train_pix_acc.reset()
            self.train_mIoU.reset()

            for i, (image, label) in enumerate(self.train_loader):
                start_time = time.time()
                image_var = image.to(self.device)
                label_var = label.to(self.device)

                output = self.model(image_var)
                loss = self.criterion(output, label_var)

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()

                end_time = time.time()

                self.train_time.update(end_time - start_time)
                self.train_loss.update(loss.item())

                if self.cfg.task == 'cls':
                    # Record classification accuracy
                    cls_acc = cal_acc(output, label_var)

                    # Update recorder
                    self.train_cls_acc.update(cls_acc.item())

                    if (i + 1) % self.cfg.log_step == 0:
                        print(
                            'Epoch[{0}][{1}/{2}]\t'
                            'Time {train_time.val:.3f} ({train_time.avg:.3f})\t'
                            'Loss {train_loss.val:.4f} ({train_loss.avg:.4f})\t'
                            'Accuracy {train_cls_acc.val:.4f} ({train_cls_acc.avg:.4f})'
                            .format(epoch + 1,
                                    i + 1,
                                    iter_per_epoch,
                                    train_time=self.train_time,
                                    train_loss=self.train_loss,
                                    train_cls_acc=self.train_cls_acc))

                    if self.cfg.use_tensorboard:
                        self.writer.add_scalar('train/loss', loss.item(),
                                               epoch * iter_per_epoch + i)
                        self.writer.add_scalar('train/accuracy',
                                               cls_acc.item(),
                                               epoch * iter_per_epoch + i)

                elif self.cfg.task == 'seg':
                    # Record mIoU and pixel-wise accuracy
                    pix_acc = cal_pixel_acc(output, label_var)
                    mIoU = cal_mIoU(output, label_var)[-1]
                    mIoU = torch.mean(mIoU)

                    # Update recorders
                    self.train_pix_acc.update(pix_acc.item())
                    self.train_mIoU.update(mIoU.item())

                    if (i + 1) % self.cfg.log_step == 0:
                        print(
                            'Epoch[{0}][{1}/{2}]\t'
                            'Time {train_time.val:.3f} ({train_time.avg:.3f})\t'
                            'Loss {train_loss.val:.4f} ({train_loss.avg:.4f})\t'
                            'Pixel-Acc {train_pix_acc.val:.4f} ({train_pix_acc.avg:.4f})\t'
                            'mIoU {train_mIoU.val:.4f} ({train_mIoU.avg:.4f})'.
                            format(epoch + 1,
                                   i + 1,
                                   iter_per_epoch,
                                   train_time=self.train_time,
                                   train_loss=self.train_loss,
                                   train_pix_acc=self.train_pix_acc,
                                   train_mIoU=self.train_mIoU))

                    if self.cfg.use_tensorboard:
                        self.writer.add_scalar('train/loss', loss.item(),
                                               epoch * iter_per_epoch + i)
                        self.writer.add_scalar('train/pix_acc', pix_acc.item(),
                                               epoch * iter_per_epoch + i)
                        self.writer.add_scalar('train/mIoU', mIoU.item(),
                                               epoch * iter_per_epoch + i)

                #FIXME currently test validation code
                #if (i + 1) % 100 == 0:
            if (epoch + 1) % self.cfg.val_step == 0:
                self.validate(epoch)

        # Close logging
        self.writer.close()

    def validate(self, epoch):
        """ Validate with validation dataset """
        self.model.eval()

        self.val_time.reset()
        self.val_loss.reset()
        self.val_cls_acc.reset()
        self.val_mIoU.reset()
        self.val_pix_acc.reset()

        iter_per_epoch = len(
            self.val_loader.dataset) // self.cfg.val_batch_size
        if len(self.val_loader.dataset) % self.cfg.val_batch_size != 0:
            iter_per_epoch += 1

        for i, (image, label) in enumerate(self.val_loader):

            start_time = time.time()
            image_var = image.to(self.device)
            label_var = label.to(self.device)

            output = self.model(image_var)
            loss = self.criterion(output, label_var)

            end_time = time.time()

            self.val_time.update(end_time - start_time)
            self.val_loss.update(loss.item())

            if self.cfg.task == 'cls':
                # Record classification accuracy
                cls_acc = cal_acc(output, label_var)

                # Update recorder
                self.val_cls_acc.update(cls_acc.item())

                if (i + 1) % self.cfg.log_step == 0:
                    print(
                        'Epoch[{0}][{1}/{2}]\t'
                        'Time {val_time.val:.3f} ({val_time.avg:.3f})\t'
                        'Loss {val_loss.val:.4f} ({val_loss.avg:.4f})\t'
                        'Accuracy {val_cls_acc.val:.4f} ({val_cls_acc.avg:.4f})'
                        .format(epoch + 1,
                                i + 1,
                                iter_per_epoch,
                                val_time=self.val_time,
                                val_loss=self.val_loss,
                                val_cls_acc=self.val_cls_acc))

                if self.cfg.use_tensorboard:
                    self.writer.add_scalar('val/loss', loss.item(),
                                           epoch * iter_per_epoch + i)
                    self.writer.add_scalar('val/accuracy', cls_acc.item(),
                                           epoch * iter_per_epoch + i)

            elif self.cfg.task == 'seg':
                # Record mIoU and pixel-wise accuracy
                pix_acc = cal_pixel_acc(output, label_var)
                mIoU = cal_mIoU(output, label_var)[-1]
                mIoU = torch.mean(mIoU)

                # Update recorders
                self.val_pix_acc.update(pix_acc.item())
                self.val_mIoU.update(mIoU.item())

                if (i + 1) % self.cfg.log_step == 0:
                    print(
                        ' ##### Validation\t'
                        'Epoch[{0}][{1}/{2}]\t'
                        'Time {val_time.val:.3f} ({val_time.avg:.3f})\t'
                        'Loss {val_loss.val:.4f} ({val_loss.avg:.4f})\t'
                        'Pixel-Acc {val_pix_acc.val:.4f} ({val_pix_acc.avg:.4f})\t'
                        'mIoU {val_mIoU.val:.4f} ({val_mIoU.avg:.4f})'.format(
                            epoch + 1,
                            i + 1,
                            iter_per_epoch,
                            val_time=self.val_time,
                            val_loss=self.val_loss,
                            val_pix_acc=self.val_pix_acc,
                            val_mIoU=self.val_mIoU))

                if self.cfg.use_tensorboard:
                    self.writer.add_scalar('val/loss', loss.item(),
                                           epoch * iter_per_epoch + i)
                    self.writer.add_scalar('val/pix_acc', pix_acc.item(),
                                           epoch * iter_per_epoch + i)
                    self.writer.add_scalar('val/mIoU', mIoU.item(),
                                           epoch * iter_per_epoch + i)

        if self.cfg.task == 'cls':
            if (epoch + 1) % self.cfg.model_save_epoch == 0:
                state = {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optim': self.optim.state_dict()
                }
                if self.best_cls < self.val_cls_acc.avg:
                    self.best_cls = self.val_cls_acc.avg
                    torch.save(
                        state, './model/cls_model_' + str(epoch + 1) + '_' +
                        str(self.val_cls_acc.avg)[0:5] + '.pth')

        elif self.cfg.task == 'seg':
            # Save segmentation samples and model
            if (epoch + 1) % self.cfg.sample_save_epoch == 0:
                pred = torch.argmax(output, dim=1)
                save_image(image, './sample/ori_' + str(epoch + 1) + '.png')
                save_image(label.unsqueeze(1),
                           './sample/true_' + str(epoch + 1) + '.png')
                save_image(pred.cpu().unsqueeze(1),
                           './sample/pred_' + str(epoch + 1) + '.png')

            if (epoch + 1) % self.cfg.model_save_epoch == 0:
                state = {
                    'epoch': epoch + 1,
                    'state_dict': self.model.state_dict(),
                    'optim': self.optim.state_dict()
                }
                if self.best_seg < self.val_pix_acc.avg:
                    self.best_seg = self.val_pix_acc.avg
                    torch.save(
                        state, './model/seg_model_' + str(epoch + 1) + '_' +
                        str(self.val_pix_acc.avg)[0:5] + '.pth')

            if self.cfg.use_tensorboard:
                image = make_grid(image)
                label = make_grid(label.unsqueeze(1))
                pred = make_grid(pred.cpu().unsqueeze(1))
                self.writer.add_image('Original', image, epoch + 1)
                self.writer.add_image('Labels', label, epoch + 1)
                self.writer.add_image('Predictions', pred, epoch + 1)

    def build_model(self):
        """ Rough """
        if self.cfg.task == 'cls':
            self.model = BinaryClassifier(num_classes=2)
        elif self.cfg.task == 'seg':
            self.model = UNet(num_classes=2)
        self.criterion = nn.CrossEntropyLoss()
        self.optim = torch.optim.Adam(self.model.parameters(),
                                      lr=self.cfg.lr,
                                      betas=(self.cfg.beta0, self.cfg.beta1))
        if self.n_gpus > 1:
            print('### {} of gpus are used!!!'.format(self.n_gpus))
            self.model = nn.DataParallel(self.model)

        self.model = self.model.to(self.device)

    def build_recorder(self):
        # Train recorder
        self.train_time = AverageMeter()
        self.train_loss = AverageMeter()

        # For classification
        self.train_cls_acc = AverageMeter()
        # For segmentation
        self.train_mIoU = AverageMeter()
        self.train_pix_acc = AverageMeter()

        # Validation recorder
        self.val_time = AverageMeter()
        self.val_loss = AverageMeter()

        # For classification
        self.val_cls_acc = AverageMeter()
        # For segmentation
        self.val_mIoU = AverageMeter()
        self.val_pix_acc = AverageMeter()

        # self.logger = Logger('./logs')
        self.best_cls = 0
        self.best_seg = 0

    def load_pre_model(self):
        """ Load pretrained model """
        print('=> loading checkpoint {}'.format(self.cfg.pre_model))
        checkpoint = torch.load(self.cfg.pre_model)
        self.start_epoch = checkpoint['epoch']
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optim.load_state_dict(checkpoint['optim'])
        print('=> loaded checkpoint {}(epoch {})'.format(
            self.cfg.pre_model, self.start_epoch))

    #TODO:Inference part:
    def infer(self, data):
        """
        input
            @data: iterable 256 x 256 patches
        output
            @output : segmentation results from each patch
                    i) If classifier's result is that there is a tissue inside of patch, outcome is a masked result.
                    ii) Otherwise, output is segmentated mask which all of pixels are background
        """
        # Data Loading

        # Load models of classification and segmetation and freeze them
        self.freeze()

        # Forward images to Classification model / Select targeted images

        # Forward images to Segmentation model

        # Record Loss / Accuracy / Pixel-Accuracy

        # Print samples out..

    def freeze(self):
        # TODO: actually freeze the classification and segmentation model parameters
        print('{}, {} have been frozen!!!'.format('model_name_1', 'model_name_2'))
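# Hedged sketch of the two-stage inference described in the `infer` docstring above:
# a classifier decides whether a 256 x 256 patch contains tissue, and only positive
# patches are passed to the segmentation model; the rest get an all-background mask.
# `cls_model` / `seg_model` are assumed to be trained BinaryClassifier / UNet instances
# as built in `build_model`; this is an illustration, not the project's implementation.
import torch

@torch.no_grad()
def infer_patches(cls_model, seg_model, patches, device):
    cls_model.eval()
    seg_model.eval()
    masks = []
    for patch in patches:                                  # patch: (3, 256, 256) tensor
        x = patch.unsqueeze(0).to(device)
        has_tissue = torch.argmax(cls_model(x), dim=1).item() == 1
        if has_tissue:
            mask = torch.argmax(seg_model(x), dim=1)       # (1, 256, 256) class map
        else:
            mask = torch.zeros(1, *patch.shape[1:], dtype=torch.long, device=device)
        masks.append(mask.squeeze(0).cpu())
    return masks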
Exemplo n.º 28
0
def main(args):
    writer = SummaryWriter(os.path.join('./logs'))
    # torch.backends.cudnn.benchmark = False
    # if not os.path.isdir(args.checkpoint_dir):
    #     os.mkdir(args.checkpoint_dir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean,
                                     std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    test_valid = 'validation' if args.valid else 'test'
    # train_data = BlurDataset(os.path.join(args.dataset_root, 'train'),
    #                         seq_len=args.sequence_length, tau=args.num_frame_blur, delta=5, transform=train_tf)
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length, tau=args.num_frame_blur, delta=5, transform=train_tf, return_path=True)

    # train_loader = DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    if args.checkpoint:
        store_dict = torch.load(args.checkpoint)
        try:
            print('Loading checkpoint...')
            model.load_state_dict(store_dict['state_dict'])
            print('Done.')
        except KeyError:
            print('Loading checkpoint...')
            model.load_state_dict(store_dict)
            print('Done.')

    model.to(device)
    model.train(False)

    # model = nn.DataParallel(model)

    # TODO DEFINE MORE CRITERIA
    # input(True if device == torch.device('cuda:0') else False)
    criterion = {
                  'MSE': nn.MSELoss(),
                  'L1' : nn.L1Loss(),
                  # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=False,
                  #                             use_gpu=True if device == torch.device('cuda:0') else False)
                  }


    # Validation
    running_loss_test = 0.0
    psnr_test = 0.0
    dssim_test = 0.0

    tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)

    loss_tracker_test = {}
    for loss_fn in criterion.keys():
        loss_tracker_test[loss_fn] = 0.0

    with torch.no_grad():
        model.eval()
        total_steps_test = 0.0
        interp_idx = int(math.ceil((args.num_frame_blur / 2) - 0.49))
        for test_idx, data in enumerate(test_loader, 1):
            loss = 0.0
            blur_data, sharpe_data, sharp_names = data
            # import pdb; pdb.set_trace()  # debugging breakpoint, left disabled
            interp_idx = int(math.ceil((args.num_frame_blur / 2) - 0.49))
            # input(interp_idx)
            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]
            else:
                # print('\nBoth\n')
                sharpe_data = sharpe_data

            # print(sharpe_data.shape)
            # input(blur_data.shape)
            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except:
                sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # forward pass
            sharpe_out = model(blur_data).float()

            # compute losses
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn in criterion.keys():
                loss_tmp = 0.0
                if loss_fn == 'Perceptual':
                    for bidx in range(B):
                        loss_tmp += criterion_w[loss_fn] * \
                                    criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                       sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                    # loss_tmp /= B
                else:
                    loss_tmp = criterion_w[loss_fn] * \
                               criterion[loss_fn](sharpe_out, sharpe_data)
                loss += loss_tmp
                try:
                    loss_tracker_test[loss_fn] += loss_tmp.item()
                except KeyError:
                    loss_tracker_test[loss_fn] = loss_tmp.item()

            # statistics
            #sharpe_out = sharpe_out.detach().cpu().numpy()
            #sharpe_data = sharpe_data.cpu().numpy()
            #  import pdb; pdb.set_trace()
            # t_grid = torchvision.utils.make_grid(torch.stack([blur_data[0], sharpe_out[0], sharpe_data[0]], dim=0),
            #                                    nrow=3)
            # tsave(t_grid, './imgs/{}/combined.jpg'.format(test_idx))
            for sidx in range(S):
                for bidx in range(B):
                    if not os.path.exists('./imgs/{}'.format(test_idx)):
                        os.makedirs('./imgs/{}'.format(test_idx))
                    blur_path = './imgs/{}/blur_input_{}.jpg'.format(test_idx, sidx)

                    # import pdb; pdb.set_trace()
                    # torchvision.utils.save_image(sharpe_out[bidx, :, sidx, :, :],blur_path, normalize=True, range=(0,255));

                    imsave(blur_data, blur_path, bidx, sidx)

                    sharp_path = './imgs/{}/sharpe_gt_{}{}.jpg'.format(test_idx, sidx, sidx)
                    imsave(sharpe_data, sharp_path, bidx, sidx)

                    deblur_path = './imgs/{}/out_{}{}.jpg'.format(test_idx, sidx, sidx)
                    imsave(sharpe_out, deblur_path, bidx, sidx)

                    if sidx > 0 and sidx < S:
                        interp_path = './imgs/{}/out_{}{}.jpg'.format(test_idx, sidx-1, sidx)
                        imsave(sharpe_out, interp_path, bidx, sidx)
                        sharp_path = './imgs/{}/sharpe_gt_{}{}.jpg'.format(test_idx, sidx-1, sidx)
                        imsave(sharpe_data, sharp_path, bidx, sidx)

                    psnr_local = psnr(im_nm * sharpe_out[bidx, :, sidx, :, :].detach().cpu().numpy(),
                                      im_nm * sharpe_data[bidx, :, sidx, :, :].cpu().numpy())
                    dssim_local = dssim(np.moveaxis(im_nm * sharpe_out[bidx, :, sidx, :, :].cpu().numpy(), 0, 2),
                                        np.moveaxis(im_nm * sharpe_data[bidx, :, sidx, :, :].cpu().numpy(), 0, 2)
                                        )
                    psnr_test += psnr_local
                    dssim_test += dssim_local
            f = open('./imgs/{0}/psnr-{1:.4f}-dssim-{2:.4f}.txt'.format(test_idx, psnr_local/(B), dssim_local/(B)),'w')
            f.close()
            running_loss_test += loss.item()
            total_steps_test += B*S
            loss_str = ''
            for key in loss_tracker_test.keys():
                loss_str += ' {0} : {1:6.4f} '.format(key, 1.0 * loss_tracker_test[key] / total_steps_test)

            # set display info

            tqdm_loader_test.set_description(
                        ('\r[Test    ] loss: {0:6.4f} PSNR: {1:6.4f} SSIM: {2:6.4f} '.format
                         ( running_loss_test / total_steps_test,
                          psnr_test / total_steps_test,
                          dssim_test / total_steps_test
                          ) + loss_str
                         )
                    )
            tqdm_loader_test.update(1)
        tqdm_loader_test.close()
    return None
Exemplo n.º 29
0
from tversky_loss import TverskyLoss
import wandb
import time
import tqdm
from torch.optim.lr_scheduler import LambdaLR
from DiceLoss import DiceBCELoss

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device}")

dataset = np.load('data_pub.zip')


model = UNet(in_dim=4, out_dim=4, num_filters=1)
print(f'Initialized Model w/ : {sum(p.numel() for p in model.parameters() if p.requires_grad)} params')
model.to(device)

epochs = 20
batch_size = 2
alpha= .5
beta = .5
lr = .05
use_wandb = True


if use_wandb:
    wandb.init(project="cs446", entity="weustis")
    wandb.watch(model)


crit = DiceBCELoss()
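# Hedged sketch of the Tversky loss the imported TverskyLoss presumably implements:
# with alpha = beta = 0.5 (as configured above) it reduces to the Dice loss, while a
# larger beta penalizes false negatives more. This is the standard formulation and
# not necessarily identical to the repository's own implementation.
import torch
import torch.nn as nn

class SimpleTverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, smooth=1e-6):
        super().__init__()
        self.alpha, self.beta, self.smooth = alpha, beta, smooth

    def forward(self, probs, targets):
        # probs and targets are expected in [0, 1] with matching shapes
        probs, targets = probs.reshape(-1), targets.reshape(-1)
        tp = (probs * targets).sum()
        fp = (probs * (1 - targets)).sum()
        fn = ((1 - probs) * targets).sum()
        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        return 1 - tversky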
Exemplo n.º 30
0
def view_result():
    # input_img_path = args.input_img_path
    # model_path = args.model_path
    # save = args.view_path

    input_img_path = './view/good/result9_img.png'
    # input_img_path = './view/valid/images/0.png'
    model_path = 'view/cardia_sstu_net_train_best_model.pth'
    save = './view/result/SSTUnet/'
    # normalize = transforms.Normalize(mean=[0.156, 0.156, 0.156],
    #                                 std=[0.174, 0.174, 0.174])  # compute per dataset (head dataset)
    normalize = transforms.Normalize(mean=[0.247, 0.260, 0.287],
                                     std=[0.189, 0.171, 0.121])  # compute per dataset (cardia_all dataset)
    # normalize = transforms.Normalize(mean=[0.191, 0.191, 0.191],
    #                                 std=[0.218, 0.218, 0.218])  # compute per dataset (camus_256)
    # transforms
    transform = transforms.Compose(
        [   #transforms.ToPILImage(),
            #transforms.RandomHorizontalFlip(p=0.5),
            #transforms.RandomVerticalFlip(p=0.5),
            #transforms.RandomGrayscale(p=0.2),  # convert to grayscale with probability p
            #transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),  # adjust brightness, contrast and saturation
            transforms.ToTensor(),
            normalize
         ])
    device = torch.device('cuda')
    model = UNet(3,1)
    #model = ResUnet(3,1)
    #model = AttU_Net(3,1)
    #model = NestedUNet(3,1)

    pretrained_dict = torch.load(model_path)

    #print(pretrained_dict)

    model.load_state_dict(pretrained_dict)
    model.to(device)
    scale = (224, 224)
    img_o = Image.open(input_img_path).resize(scale)
    #img_o.show()

    img = transform(img_o).unsqueeze(0)

    print("img_np:",img.shape)
    #img_rensor= torch.from_numpy(img/255.0)

    #img = img_rensor.unsqueeze(0).unsqueeze(0).float()
    print("img:", img.size())
    img_ = img.to(device)
    with torch.no_grad():
        outputs = model(img_)
    outputs = torch.sigmoid(outputs)
    outputs = (outputs > 0.5).float()
    output_np = outputs.cpu().numpy().squeeze()
    print("output_np:", output_np.shape)
    name  = os.path.basename(input_img_path).split('.')[0]
    print(os.path.dirname(os.path.dirname(input_img_path)))
    #target= Image.open(os.path.dirname(os.path.dirname(input_img_path))+f'/mask/{name}_label.png').resize(scale)

    output = Image.fromarray((output_np*255).astype('uint8')).convert('1')
    #output.show()
    img_o.save(save + f'/{name}.png')
    output.save(save + f'/{name}_result.png')