Esempio n. 1
0
def build(config):
    """Populate the module-level data loaders, model, loss and optimizer.

    The dataset, the number of classes and the loss are selected from the
    module-level ``base_config`` mapping; the architecture, optimizer and
    learning rate come from ``config``.  Nothing is returned: the results
    are published through the globals ``train_dataloader``,
    ``val_dataloader``, ``test_dataloader``, ``model``, ``loss_func`` and
    ``optimizer``.

    An unsupported dataset, model, loss or optimizer name is reported with
    ``_logger.error`` and the corresponding global is not assigned by this
    call.

    Args:
        config: Mapping providing the keys ``'model'``, ``'optimizer'`` and
            ``'lr'``.
    """
    global train_dataloader
    global val_dataloader
    global test_dataloader
    global model
    global loss_func
    global optimizer
    # ========= Build Data ==============
    if base_config['dataset'] == 'kaggle':
        from data import build_kaggle_dataset
        train_dataloader, val_dataloader, test_dataloader = build_kaggle_dataset(
            base_config)
    elif base_config['dataset'] == 'drive':
        from data import build_drive_dataset
        train_dataloader, val_dataloader, test_dataloader = build_drive_dataset(
            base_config)
    else:
        _logger.error('{} dataset is not supported now'.format(
            base_config['dataset']))
    # ========= Build Model =============
    if config['model'] == 'resnet101':
        from torchvision.models import resnet101
        model = resnet101(num_classes=base_config['n_classes'])
    elif config['model'] == 'resnext101':
        from torchvision.models import resnext101_32x8d
        model = resnext101_32x8d(num_classes=base_config['n_classes'])
    elif config['model'] == 'densenet':
        from torchvision.models import densenet121
        model = densenet121(num_classes=base_config['n_classes'])
    elif config['model'] == 'unet':
        from models import UNet
        model = UNet(num_classes=base_config['n_classes'])
    else:
        _logger.error('{} model is not supported'.format(config['model']))
    # NOTE: if the model name was not recognised, this wraps whatever the
    # global ``model`` currently holds (or fails if it was never assigned).
    model = torch.nn.DataParallel(model.cuda())
    # ========= Build Loss ==============
    if base_config['loss'] == 'ce':
        loss_func = torch.nn.CrossEntropyLoss().cuda()
    elif base_config['loss'] == 'bce':
        loss_func = torch.nn.BCELoss().cuda()
    elif base_config['loss'] == 'MSE':
        loss_func = torch.nn.MSELoss().cuda()
    else:
        # The loss name lives in base_config, so report it from there; the
        # per-run ``config`` is not read for the loss anywhere else.
        _logger.error('{} loss is not supported'.format(base_config['loss']))
    # ========= Build Optimizer =========
    optimizer_name = config['optimizer']
    if optimizer_name == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config['lr'],
                                    momentum=0.9,
                                    weight_decay=5e-4)
    elif optimizer_name == 'Adadelta':
        optimizer = torch.optim.Adadelta(model.parameters(), lr=config['lr'])
    elif optimizer_name == 'Adagrad':
        optimizer = torch.optim.Adagrad(model.parameters(), lr=config['lr'])
    elif optimizer_name == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
    elif optimizer_name == 'Adamax':
        # Previously this branch silently built a plain Adam optimizer.
        optimizer = torch.optim.Adamax(model.parameters(), lr=config['lr'])
    else:
        _logger.error('{} optimizer is not supported'.format(optimizer_name))
Esempio n. 2
0
def train(device, model_path, dataset_path):
    """Train a ``UNet(1, 3)`` on the gray/color pairs found at ``dataset_path``.

    If ``model_path`` already exists its weights are loaded first, so a
    previous run is continued.  The network is trained for 10 epochs with
    Adam and an MSE loss, and its ``state_dict`` is written back to
    ``model_path`` after every epoch.

    Args:
        device: Torch device the network and each batch are moved to.
        model_path: Checkpoint file to resume from (if present) and to
            overwrite after every epoch.
        dataset_path: Location handed to ``GrayColorDataset``.
    """
    network = UNet(1, 3).to(device)
    optimizer = torch.optim.Adam(network.parameters())
    criteria = torch.nn.MSELoss()

    dataset = GrayColorDataset(dataset_path, transform=train_transform)
    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=16,
                                         shuffle=True,
                                         num_workers=cpu_count())

    if os.path.exists(model_path):
        # map_location lets a checkpoint written on a GPU be restored when
        # training on a different device (e.g. CPU-only machines).
        network.load_state_dict(torch.load(model_path, map_location=device))
    for _ in tqdm.trange(10, desc="Epoch"):
        network.train()
        for gray, color in tqdm.tqdm(loader, desc="Training", leave=False):
            gray, color = gray.to(device), color.to(device)
            optimizer.zero_grad()
            pred_color = network(gray)
            loss = criteria(pred_color, color)
            loss.backward()
            optimizer.step()
        # Checkpoint once per epoch so an interrupted run loses little work.
        torch.save(network.state_dict(), model_path)
Esempio n. 3
0
def main(opt):
    """Set up the generator/discriminator pair and run the requested mode.

    A ``SummaryWriter`` is created and its log directory gets ``images`` and
    ``test`` sub-directories.  Both networks are moved to CUDA when it is
    available (CPU otherwise) and each gets its own Adam optimizer.

    Args:
        opt: Options object; the attributes read here are ``sample_num``,
            ``channels``, ``batch_size``, ``alpha``, ``lr_g``, ``lr_d``,
            ``b1``, ``b2`` and ``mode`` (``'train'`` or ``'test'``).
    """
    writer = SummaryWriter()
    log_dir = writer.get_logdir()
    for sub_dir in ("images", "test"):
        os.makedirs(os.path.join(log_dir, sub_dir), exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Networks: build both first, then place them on the chosen device.
    generator = UNet(opt.sample_num, opt.channels, opt.batch_size, opt.alpha)
    discriminator = Discriminator(opt.batch_size, opt.alpha)
    generator.to(device=device)
    discriminator.to(device=device)

    # One Adam optimizer per network, sharing the beta coefficients.
    optimizer_G = torch.optim.Adam(
        generator.parameters(), lr=opt.lr_g, betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(
        discriminator.parameters(), lr=opt.lr_d, betas=(opt.b1, opt.b2))

    if opt.mode == 'train':
        # Training hands back the generator, which is evaluated right away.
        generator = train(writer, log_dir, device, generator, discriminator,
                          optimizer_G, optimizer_D, opt)
        test(opt, log_dir, generator=generator)
    if opt.mode == 'test':
        test(opt, log_dir)
        test_moving(opt, log_dir)
Esempio n. 4
0
def main():
    """Train a U-Net with Dice loss for 100 epochs and return it.

    The channel counts are taken from ``Dataset.in_channels`` and
    ``Dataset.out_channels``.  Training runs on the CPU as written (the
    ``.cuda()`` transfers are disabled) and prints the loss of every step.

    Returns:
        The trained ``UNet`` instance.
    """
    # Define the net.  NOTE: append .cuda() here (and to the batches below)
    # to train on a GPU.
    net = UNet(in_channels=Dataset.in_channels,
               out_channels=Dataset.out_channels)
    criterion = DiceLoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    # Prepare the dataset
    dataloader = torch.utils.data.DataLoader(
        Dataset('/home/mooziisp/GitRepos/unet/data/membrane/train/image'),
        batch_size=1,
        shuffle=True,
        num_workers=2,
        pin_memory=True)
    # Derive the global step from the loader instead of assuming a fixed
    # number of batches per epoch.
    steps_per_epoch = len(dataloader)

    # TODO validation
    # TODO acc and loss record
    # TODO integrate with tensorboard
    for epoch in range(100):
        for i, (images, targets) in enumerate(dataloader):
            preds = net(images)
            loss = criterion(preds, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f'{epoch * steps_per_epoch + i}it, loss: {loss.item():.3f}')

    return net
Esempio n. 5
0
def run(conf, data):
    """Train and/or test one U-Net per split file and write a global score.

    Every file in ``data['splits_dir']`` is loaded as a split.  For each one
    a fresh model, optimizer and ``KernelTrainer`` are created; training runs
    only when ``conf['mode'] == 'train'``, while the 'best' checkpoint is
    always restored and evaluated on the split's test images.  Any exception
    raised while handling a split is printed and the loop moves on.  The
    accumulated precision, recall, F1 and accuracy are finally written to
    ``global_score.txt`` in ``conf['log_dir']`` (default ``net_logs``).

    Args:
        conf: Run configuration mapping; ``conf['log_key']`` is overwritten
            with the current split file name.
        data: Data configuration mapping providing ``'splits_dir'``.
    """
    score = ConfusionMatrix(num_classes)

    for split_file in os.listdir(data['splits_dir']):
        conf['log_key'] = split_file
        split = du.load_split_json(data['splits_dir'] + sep + split_file)

        model = UNet(conf['input_channels'], conf['num_classes'], reduce_by=2)
        learning_rate = conf['learning_rate']
        if conf['distribute']:
            # Wrap for multi-GPU use; the optimizer then targets the wrapped
            # module's parameters.
            model = torch.nn.DataParallel(model)
            model.float()
            trainable_params = model.module.parameters()
        else:
            trainable_params = model.parameters()
        optimizer = optim.Adam(trainable_params, lr=learning_rate)

        try:
            trainer = KernelTrainer(run_conf=conf, model=model, optimizer=optimizer)

            if conf.get('mode') == 'train':
                train_loader = KernelDataset.get_loader(
                    shuffle=True, mode='train', transforms=transforms,
                    images=split['train'], data_conf=data, run_conf=conf)
                validation_loader = KernelDataset.get_loader(
                    shuffle=True, mode='validation', transforms=transforms,
                    images=split['validation'], data_conf=data, run_conf=conf)

                print('### Train Val Batch size:', len(train_loader), len(validation_loader))
                trainer.train(train_loader=train_loader, validation_loader=validation_loader)

            test_loader = KernelDataset.get_loader(
                shuffle=False, mode='test', transforms=transforms,
                images=split['test'], data_conf=data, run_conf=conf)

            trainer.resume_from_checkpoint(
                checkpoint_file=conf.get("checkpoint_file"),
                parallel_trained=conf.get('parallel_trained'),
                key='best')
            trainer.test(data_loader=test_loader, global_score=score,
                         logger=trainer.test_logger)
        except Exception:
            # Best effort: report the failure and continue with the next split.
            traceback.print_exc()

    score_path = conf.get('log_dir', 'net_logs') + os.sep + 'global_score.txt'
    with open(score_path, 'w') as lg:
        lg.write(f'{score.precision()},{score.recall()},{score.f1()},{score.accuracy()}')
        lg.flush()
Esempio n. 6
0
            train_loss.append(
                epoch_loss) if phase == 'train' else valid_loss.append(
                    epoch_loss)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    model.load_state_dict(best_model_wts)
    torch.save(model, (os.getcwd() + args.out_dir))
    return model, train_loss, valid_loss


# Hand only the parameters that still require gradients to Adam, so any
# parameter with requires_grad=False is excluded from the updates.
optimizer_f = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                         lr=1e-3)
# Run training; the call hands back the model together with the train and
# validation loss histories that are plotted below.
model, train_loss, valid_loss = train_model(model,
                                            optimizer_f,
                                            num_epochs=args.n_epochs)

# Optional visualisation: plot both loss curves in one figure.
if args.visualize:

    plt.figure(1, figsize=(10, 8))
    plt.plot(train_loss, label='Train loss')
    plt.plot(valid_loss, label='Valid loss')
    plt.legend()
    plt.show()

    # Switch the model to evaluation mode after plotting.
    model.eval()
Esempio n. 7
0
from torchvision.datasets import ImageFolder
import os
from models import UNet
from utils import Utils, TransForms
from traintest.traintest_unet import UNetTrainTest
from config import config

if __name__ == '__main__':
    # Script entry point: build the transforms and datasets, report their
    # sizes together with the model's trainable-parameter count, then train.
    tfms = TransForms(config.img_size)
    utils = Utils()

    train_dir = os.path.join('breast_cancer', 'train')
    test_dir = os.path.join('breast_cancer', 'test')
    train_dataset = ImageFolder(train_dir, transform=tfms.train_tfms)
    test_dataset = ImageFolder(test_dir, transform=tfms.test_tfms)

    model = UNet()
    # Count only the parameters that will actually be optimised.
    num_params = sum(
        p.numel()
        for p in filter(lambda p: p.requires_grad, model.parameters()))

    summary = "Number of train images: {}\nNumber of test images: {}\nNumber of model trainable parameters: {}".format(
        len(train_dataset.imgs), len(test_dataset.imgs), num_params)
    print(summary)

    train_test = UNetTrainTest(config, model, train_dataset, test_dataset,
                               utils)
    train_test.train()
import torch.nn.functional as F
import torch.optim as optim
from . import config as cfg
from models import UNet

# Experiment definition: a two-class U-Net placed on cfg.device, trained
# with cross-entropy and an Adam optimizer using its default settings.
n_classes = 2
model_name = 'Base-UNet'
criterion = F.cross_entropy
model = UNet(n_classes).to(cfg.device)
optimizer = optim.Adam(model.parameters())
Esempio n. 9
0
def train_val(config):
    """Train a segmentation model and validate it after every epoch.

    The model is chosen by ``config.model_type`` (falling back to ``UNet``),
    the optimizer by ``config.optimizer`` (falling back to Adam), and the
    training loss is always ``BasLoss``.  After each epoch the validation
    set is scored with a confusion matrix; the frequency-weighted IoU
    decides whether the whole model object is saved to
    ``config.result_path``.  Losses, IoU metrics and a few sample masks are
    written to TensorBoard.

    Args:
        config: Settings object.  Attributes read here: ``train_img_dir``,
            ``train_mask_dir``, ``val_img_dir``, ``val_mask_dir``,
            ``batch_size``, ``num_workers``, ``smooth``, ``lr``,
            ``model_type``, ``data_type``, ``output_ch``, ``iscontinue``,
            ``optimizer``, ``num_epochs``, ``num_train``, ``num_val`` and
            ``result_path``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    # Validation always uses a fixed batch size of 4; the sample logging
    # below indexes items 2 and 3 of a batch.
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=4,
                                num_workers=config.num_workers)

    # The run comment encodes the main hyper-parameters of this run.
    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    # Model selection; unknown names fall back to a plain UNet.
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "BASNet":
        model = BASNet(n_classes=8)
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
    elif config.model_type == "Deeplabv3+":
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=8,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    # Resuming replaces the model built above with a pickled checkpoint from
    # a hard-coded path, regardless of config.model_type.
    # NOTE(review): `.module` presumes the checkpoint was saved while wrapped
    # in DataParallel - confirm.
    if config.iscontinue:
        model = torch.load("./exp/24_Deeplabv3+_0.7825757691389714.pth").module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # PyTorch 1.6.0 compatibility

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    # Label values and display names of the 8 classes; `objects` is only
    # used to name the per-class IoU scalars logged during validation.
    labels = [100, 200, 300, 400, 500, 600, 700, 800]
    objects = ['水体', '交通建筑', '建筑', '耕地', '草地', '林地', '裸土', '其他']

    # Optimizer selection; anything other than "sgd"/"adamw" uses Adam.
    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    criterion = BasLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0  # best frequency-weighted IoU seen so far
    # Per-class weights applied to the per-class IoU when computing the
    # frequency-weighted IoU below.
    frequency = np.array(
        [0.1051, 0.0607, 0.1842, 0.1715, 0.0869, 0.1572, 0.0512, 0.1832])
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        cm = np.zeros([8, 8])  # confusion matrix, reset every epoch
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)
                # NOTE(review): the mask is cast to float16 while the image
                # is float32 - confirm BasLoss expects a half-precision target.
                mask = mask.to(device, dtype=torch.float16)

                pred = model(image)
                loss = criterion(pred, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            # The mean is taken over num_train / batch_size batches.
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
            torch.cuda.empty_cache()

        # val_loss is the sum of the per-batch cross-entropy values; it is
        # not divided by the number of batches.
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0  # index of the current validation batch
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    # Class-index target obtained by argmax over dim 1 of the mask.
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    # NOTE(review): this unpacking requires the model to
                    # return exactly 8 values - confirm it holds for every
                    # selectable model_type.
                    pred, _, _, _, _, _, _, _ = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    # On the 26th validation batch, log items 2 and 3 of the
                    # batch (ground truth and prediction) as sample images.
                    if locker == 25:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                # The learning-rate schedule advances once per epoch, after
                # validation.
                scheduler.step()

                # Save the whole model object whenever the weighted IoU
                # improves; the file name carries epoch, model type and score.
                if fw_miou > max_fwiou:
                    # Only a version string of exactly "1.6.0" selects the
                    # legacy (non-zipfile) serialization.
                    if torch.__version__ == "1.6.0":
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('mIoU/val', miou.mean(), epoch + 1)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                # One IoU scalar per class, tagged with the class name.
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
                torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
Esempio n. 10
0
def train():
    """Train the U-Net on the band-split low-light hyperspectral dataset.

    Hyper-parameters (``batch_size``, ``input_dim``, ``label_dim``,
    ``device``, ``lr``, ``step_size``, ``n_epochs``, ``display_step``) and
    the loss ``criterion`` are read from module-level names.  The
    ``state_dict`` is checkpointed every second epoch under
    ``./checkpoints`` and the StepLR schedule advances once per epoch.
    """
    transform = transforms.Compose([
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ])
    dataset = SpectralDataSet(
        root_dir=
        '/mnt/liguanlin/DataSets/lowlight_hyperspectral_datasets/band_splited_dataset',
        type_name='train',
        transform=transform)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    unet = UNet(input_dim, label_dim).to(device)
    unet.init_weight()
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    scheduler = lr_scheduler.StepLR(unet_opt, step_size, gamma=0.4)
    cur_step = 0

    for epoch in range(n_epochs):
        train_l_sum, batch_count = 0.0, 0

        for real, labels in tqdm(dataloader):
            real = real.to(device)
            labels = labels.to(device)

            # Update U-Net
            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, labels)
            unet_loss.backward()
            unet_opt.step()

            train_l_sum += unet_loss.item()
            batch_count += 1

            if cur_step % display_step == 0:
                print(
                    f"Epoch {epoch}: Step {cur_step}: U-Net loss: {unet_loss.item()}"
                )
            cur_step += 1

        if (epoch + 1) % 2 == 0:
            torch.save(unet.state_dict(),
                       './checkpoints/checkpoint_{}.pth'.format(epoch + 1))

        # Update the learning rate.  This must step the scheduler: stepping
        # the optimizer here would re-apply the last batch's gradients and
        # leave the learning rate unchanged.
        scheduler.step()

        # max() avoids a ZeroDivisionError when the loader yields no batches.
        print('epoch %d, train loss %.4f' %
              (epoch + 1, train_l_sum / max(batch_count, 1)))
Esempio n. 11
0
        model.load_state_dict(torch.load(pretrained_model_path))
        epoch_start = os.path.basename(pretrained_model_path).split('.')[0]
        print(epoch_start)

    trainLoader = DataLoader(DatasetImageMaskLocal(train_file_names,
                                                   object_type,
                                                   mode='train'),
                             batch_size=batch_size)
    devLoader = DataLoader(
        DatasetImageMaskLocal(val_file_names, object_type, mode='valid'))
    displayLoader = DataLoader(DatasetImageMaskLocal(val_file_names,
                                                     object_type,
                                                     mode='valid'),
                               batch_size=val_batch_size)

    optimizer = Adam(model.parameters(), lr=1e-4)

    criterion_global = LossMulti(num_classes=2, jaccard_weight=0)
    criterion_local = LossMulti(num_classes=2, jaccard_weight=0)

    for epoch in tqdm(
            range(int(epoch_start) + 1,
                  int(epoch_start) + 1 + no_of_epochs)):

        global_step = epoch * len(trainLoader)
        running_loss = 0.0
        running_loss_local = 0.0

        for i, (inputs, targets, coord) in enumerate(tqdm(trainLoader)):

            model.train()
    train_loader = make_loader(train_file_names,channels, shuffle=True, transform=train_transform, mode='train', batch_size = args.batch_size)  
    valid_loader = make_loader(val_file_names,channels, transform=val_transform, batch_size = args.batch_size, mode = "train")
  
    dataloaders = {
    'train': train_loader, 'val': valid_loader
    }

    dataloaders_sizes = {
    x: len(dataloaders[x]) for x in dataloaders.keys()
    }
    
    root.joinpath(('params_{}_{}.json').format(args.dataset_file,args.n_epochs)).write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))
    
    optimizer_ft = optim.Adam(model.parameters(), lr= args.lr)  #
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1) 
    
    
    utilsTrain.train_model(
        dataset_file=args.dataset_file,
        name_file=name_file,
        model=model,
        optimizer=optimizer_ft,
        scheduler=exp_lr_scheduler,
        dataloaders=dataloaders,
        fold_out=args.fold_out,
        fold_in=args.fold_in,
        name_model=args.model,
        num_epochs=args.n_epochs 
        )
Esempio n. 13
0
                         transform=tr.ToTensor())

    pred_loader = DataLoader(dset_pred,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=1)
    print("Prediction Data : ", len(pred_loader.dataset))

# %% Loading in the model
model = UNet()

if args.cuda:
    model.cuda()

# Exactly one optimizer has to be selected; an unknown name is rejected here
# instead of surfacing later as a NameError when the scheduler is built.
if args.optimizer == 'SGD':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.99)
elif args.optimizer == 'ADAM':
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2))
else:
    raise ValueError(
        "Unsupported optimizer {!r}; expected 'SGD' or 'ADAM'".format(
            args.optimizer))

# Multiply the learning rate by 0.1 every 10 scheduler steps.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# Defining Loss Function
criterion = DICELossMultiClass()


def train(epoch, scheduler, loss_lsit):
    scheduler.step()
    model.train()
    for batch_idx, (image, mask) in enumerate(train_loader):
Esempio n. 14
0
def main():
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', type=float, default=1)
    arg('--root', type=str, default='runs/debug', help='checkpoint root')
    arg('--image-path', type=str, default='data', help='image path')
    arg('--batch-size', type=int, default=2)
    arg('--n-epochs', type=int, default=100)
    arg('--optimizer', type=str, default='Adam', help='Adam or SGD')
    arg('--lr', type=float, default=0.001)
    arg('--workers', type=int, default=10)
    arg('--model',
        type=str,
        default='UNet16',
        choices=[
            'UNet', 'UNet11', 'UNet16', 'LinkNet34', 'FCDenseNet57',
            'FCDenseNet67', 'FCDenseNet103'
        ])
    arg('--model-weight', type=str, default=None)
    arg('--resume-path', type=str, default=None)
    arg('--attribute',
        type=str,
        default='all',
        choices=[
            'pigment_network', 'negative_network', 'streaks',
            'milia_like_cyst', 'globules', 'all'
        ])
    args = parser.parse_args()

    ## folder for checkpoint
    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    image_path = args.image_path

    #print(args)
    if args.attribute == 'all':
        num_classes = 5
    else:
        num_classes = 1
    args.num_classes = num_classes
    ### save initial parameters
    print('--' * 10)
    print(args)
    print('--' * 10)
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    ## load pretrained model
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'FCDenseNet103':
        model = FCDenseNet103(num_classes=num_classes)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    ## multiple GPUs
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    model.to(device)

    ## load pretrained model
    if args.model_weight is not None:
        state = torch.load(args.model_weight)
        #epoch = state['epoch']
        #step = state['step']
        model.load_state_dict(state['model'])
        print('--' * 10)
        print('Load pretrained model', args.model_weight)
        #print('Restored model, epoch {}, step {:,}'.format(epoch, step))
        print('--' * 10)
        ## replace the last layer
        ## although the model and pre-trained weight have differernt size (the last layer is different)
        ## pytorch can still load the weight
        ## I found that the weight for one layer just duplicated for all layers
        ## therefore, the following code is not necessary
        # if args.attribute == 'all':
        #     model = list(model.children())[0]
        #     num_filters = 32
        #     model.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
        #     print('--' * 10)
        #     print('Load pretrained model and replace the last layer', args.model_weight, num_classes)
        #     print('--' * 10)
        #     if torch.cuda.device_count() > 1:
        #         model = nn.DataParallel(model)
        #     model.to(device)

    ## model summary
    print_model_summay(model)

    ## define loss
    loss_fn = LossBinary(jaccard_weight=args.jaccard_weight)

    ## It enables benchmark mode in cudnn.
    ## benchmark mode is good whenever your input sizes for your network do not vary. This way, cudnn will look for the
    ## optimal set of algorithms for that particular configuration (which takes some time). This usually leads to faster runtime.
    ## But if your input sizes changes at each iteration, then cudnn will benchmark every time a new size appears,
    ## possibly leading to worse runtime performances.
    cudnn.benchmark = True

    ## get train_test_id
    train_test_id = get_split()

    ## train vs. val
    print('--' * 10)
    print('num train = {}, num_val = {}'.format(
        (train_test_id['Split'] == 'train').sum(),
        (train_test_id['Split'] != 'train').sum()))
    print('--' * 10)

    train_transform = DualCompose(
        [HorizontalFlip(),
         VerticalFlip(),
         ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    ## define data loader
    train_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=True,
                               shuffle=True,
                               transform=train_transform)
    valid_loader = make_loader(train_test_id,
                               image_path,
                               args,
                               train=False,
                               shuffle=True,
                               transform=val_transform)

    if True:
        print('--' * 10)
        print('check data')
        train_image, train_mask, train_mask_ind = next(iter(train_loader))
        print('train_image.shape', train_image.shape)
        print('train_mask.shape', train_mask.shape)
        print('train_mask_ind.shape', train_mask_ind.shape)
        print('train_image.min', train_image.min().item())
        print('train_image.max', train_image.max().item())
        print('train_mask.min', train_mask.min().item())
        print('train_mask.max', train_mask.max().item())
        print('train_mask_ind.min', train_mask_ind.min().item())
        print('train_mask_ind.max', train_mask_ind.max().item())
        print('--' * 10)

    valid_fn = validation_binary

    ###########
    ## optimizer
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9)

    ## loss
    criterion = loss_fn
    ## change LR
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=0.8,
                                  patience=5,
                                  verbose=True)

    ##########
    ## load previous model status
    previous_valid_loss = 10
    model_path = root / 'model.pt'
    if args.resume_path is not None and model_path.exists():
        state = torch.load(str(model_path))
        epoch = state['epoch']
        step = state['step']
        model.load_state_dict(state['model'])
        epoch = 1
        step = 0
        try:
            previous_valid_loss = state['valid_loss']
        except:
            previous_valid_loss = 10
        print('--' * 10)
        print('Restored previous model, epoch {}, step {:,}'.format(
            epoch, step))
        print('--' * 10)
    else:
        epoch = 1
        step = 0

    #########
    ## start training
    log = root.joinpath('train.log').open('at', encoding='utf8')
    writer = SummaryWriter()
    meter = AllInOneMeter()
    #if previous_valid_loss = 10000
    print('Start training')
    print_model_summay(model)
    previous_valid_jaccard = 0
    for epoch in range(epoch, args.n_epochs + 1):
        model.train()
        random.seed()
        #jaccard = []
        start_time = time.time()
        meter.reset()
        w1 = 1.0
        w2 = 0.5
        w3 = 0.5
        try:
            train_loss = 0
            valid_loss = 0
            # if epoch == 1:
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     #set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 5:
            #     w1 = 1.0
            #     w2 = 0.0
            #     w3 = 0.5
            #     freeze_layer_names = get_freeze_layer_names(part='encoder')
            #     set_freeze_layers(model, freeze_layer_names=freeze_layer_names)
            #     # set_train_layers(model, train_layer_names=['module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            #elif epoch == 3:
            #     set_train_layers(model, train_layer_names=['module.dec5.block.0.conv.weight','module.dec5.block.0.conv.bias',
            #                                                'module.dec5.block.1.weight','module.dec5.block.1.bias',
            #                                                'module.dec4.block.0.conv.weight','module.dec4.block.0.conv.bias',
            #                                                'module.dec4.block.1.weight','module.dec4.block.1.bias',
            #                                                'module.dec3.block.0.conv.weight','module.dec3.block.0.conv.bias',
            #                                                'module.dec3.block.1.weight','module.dec3.block.1.bias',
            #                                                'module.dec2.block.0.conv.weight','module.dec2.block.0.conv.bias',
            #                                                'module.dec2.block.1.weight','module.dec2.block.1.bias',
            #                                                'module.dec1.conv.weight','module.dec1.conv.bias',
            #                                                'module.final.weight','module.final.bias'])
            #     print_model_summay(model)
            # elif epoch == 50:
            #     set_freeze_layers(model, freeze_layer_names=None)
            #     print_model_summay(model)
            for i, (train_image, train_mask,
                    train_mask_ind) in enumerate(train_loader):
                # inputs, targets = variable(inputs), variable(targets)

                train_image = train_image.permute(0, 3, 1, 2)
                train_mask = train_mask.permute(0, 3, 1, 2)
                train_image = train_image.to(device)
                train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                train_mask_ind = train_mask_ind.to(device).type(
                    torch.cuda.FloatTensor)
                # if args.problem_type == 'binary':
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)
                # else:
                #     #train_mask = train_mask.to(device).type(torch.cuda.LongTensor)
                #     train_mask = train_mask.to(device).type(torch.cuda.FloatTensor)

                outputs, outputs_mask_ind1, outputs_mask_ind2 = model(
                    train_image)
                #print(outputs.size())
                #print(outputs_mask_ind1.size())
                #print(outputs_mask_ind2.size())
                ### note that the last layer in the model is defined differently
                # if args.problem_type == 'binary':
                #     train_prob = F.sigmoid(outputs)
                #     loss = criterion(outputs, train_mask)
                # else:
                #     #train_prob = outputs
                #     train_prob = F.sigmoid(outputs)
                #     loss = torch.tensor(0).type(train_mask.type())
                #     for feat_inx in range(train_mask.shape[1]):
                #         loss += criterion(outputs, train_mask)
                train_prob = F.sigmoid(outputs)
                train_mask_ind_prob1 = F.sigmoid(outputs_mask_ind1)
                train_mask_ind_prob2 = F.sigmoid(outputs_mask_ind2)
                loss1 = criterion(outputs, train_mask)
                #loss1 = F.binary_cross_entropy_with_logits(outputs, train_mask)
                #loss2 = nn.BCEWithLogitsLoss()(outputs_mask_ind1, train_mask_ind)
                #print(train_mask_ind.size())
                #weight = torch.ones_like(train_mask_ind)
                #weight[:, 0] = weight[:, 0] * 1
                #weight[:, 1] = weight[:, 1] * 14
                #weight[:, 2] = weight[:, 2] * 14
                #weight[:, 3] = weight[:, 3] * 4
                #weight[:, 4] = weight[:, 4] * 4
                #weight = weight * train_mask_ind + 1
                #weight = weight.to(device).type(torch.cuda.FloatTensor)
                loss2 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind1, train_mask_ind)
                loss3 = F.binary_cross_entropy_with_logits(
                    outputs_mask_ind2, train_mask_ind)
                #loss3 = criterion(outputs_mask_ind2, train_mask_ind)
                loss = loss1 * w1 + loss2 * w2 + loss3 * w3
                #print(loss1.item(), loss2.item(), loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1
                #jaccard += [get_jaccard(train_mask, (train_prob > 0).float()).item()]
                meter.add(train_prob, train_mask, train_mask_ind_prob1,
                          train_mask_ind_prob2, train_mask_ind, loss1.item(),
                          loss2.item(), loss3.item(), loss.item())
                # print(train_mask.data.shape)
                # print(train_mask.data.sum(dim=-2).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).shape)
                # print(train_mask.data.sum(dim=-2).sum(dim=-1).sum(dim=0).shape)
                # intersection = train_mask.data.sum(dim=-2).sum(dim=-1)
                # print(intersection.shape)
                # print(intersection.dtype)
                # print(train_mask.data.shape[0])
                #torch.zeros([2, 4], dtype=torch.float32)
            #########################
            ## at the end of each epoch, evualte the metrics
            epoch_time = time.time() - start_time
            train_metrics = meter.value()
            train_metrics['epoch_time'] = epoch_time
            train_metrics['image'] = train_image.data
            train_metrics['mask'] = train_mask.data
            train_metrics['prob'] = train_prob.data

            #train_jaccard = np.mean(jaccard)
            #train_auc = str(round(mtr1.value()[0],2))+' '+str(round(mtr2.value()[0],2))+' '+str(round(mtr3.value()[0],2))+' '+str(round(mtr4.value()[0],2))+' '+str(round(mtr5.value()[0],2))
            valid_metrics = valid_fn(model, criterion, valid_loader, device,
                                     num_classes)
            ##############
            ## write events
            write_event(log,
                        step,
                        epoch=epoch,
                        train_metrics=train_metrics,
                        valid_metrics=valid_metrics)
            #save_weights(model, model_path, epoch + 1, step)
            #########################
            ## tensorboard
            write_tensorboard(writer,
                              model,
                              epoch,
                              train_metrics=train_metrics,
                              valid_metrics=valid_metrics)
            #########################
            ## save the best model
            valid_loss = valid_metrics['loss1']
            valid_jaccard = valid_metrics['jaccard']
            if valid_loss < previous_valid_loss:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_loss = valid_loss
                print('Save best model by loss')
            if valid_jaccard > previous_valid_jaccard:
                save_weights(model, model_path, epoch + 1, step, train_metrics,
                             valid_metrics)
                previous_valid_jaccard = valid_jaccard
                print('Save best model by jaccard')
            #########################
            ## change learning rate
            scheduler.step(valid_metrics['loss1'])

        except KeyboardInterrupt:
            # print('--' * 10)
            # print('Ctrl+C, saving snapshot')
            # save_weights(model, model_path, epoch, step)
            # print('done.')
            # print('--' * 10)
            writer.close()
            #return
    writer.close()
Esempio n. 15
0
def main():
    """Parse command-line options, build model/loss/loaders and start training.

    Steps, in order:
      1. Parse CLI arguments and create a timestamped run directory under
         ``--root``.
      2. Validate the train/validation crop sizes with
         ``utils.check_crop_size``; exit with status 1 if either is rejected.
      3. Instantiate the model selected by ``--model`` and wrap it in
         ``nn.DataParallel`` on the requested GPUs.
      4. Build the loss (``LossBinary`` for ``--type binary``, otherwise a
         class-weighted ``LossMulti``).
      5. Build the train/validation loaders, dump the parameters to
         ``params.json`` and hand everything to ``utils.train``.

    Raises:
        SystemError: if no CUDA device is available.
        SystemExit: with status 1 if a crop size is rejected.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids', type=str, default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--filepath', type=str, help='folder with images and annotation masks')
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=32)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--train_crop_height', type=int, default=416)
    arg('--train_crop_width', type=int, default=416)
    arg('--val_crop_height', type=int, default=416)
    arg('--val_crop_width', type=int, default=416)
    arg('--type', type=str, default='binary', choices=['binary', 'multi'])
    arg('--model', type=str, default='UNet', choices=model_list.keys())
    arg('--datatype', type=str, default='buildings',
        choices=['buildings', 'roads', 'combined'])
    arg('--pretrained', action='store_true',
        help='use pretrained network for initialisation')
    arg('--num_classes', type=int, default=1)

    args = parser.parse_args()

    timestr = time.strftime("%Y%m%d-%H%M%S")

    # Every run writes into its own timestamped sub-directory of --root.
    root = Path(args.root) / timestr
    root.mkdir(exist_ok=True, parents=True)
    dataset_type = args.datatype
    print('log', root, dataset_type)

    if not utils.check_crop_size(args.train_crop_height, args.train_crop_width):
        print('Input image sizes should be divisible by 32, but train '
              'crop sizes ({train_crop_height} and {train_crop_width}) '
              'are not.'.format(train_crop_height=args.train_crop_height,
                                train_crop_width=args.train_crop_width))
        # Non-zero status so shells / schedulers see the run as failed.
        sys.exit(1)

    if not utils.check_crop_size(args.val_crop_height, args.val_crop_width):
        print('Input image sizes should be divisible by 32, but validation '
              'crop sizes ({val_crop_height} and {val_crop_width}) '
              'are not.'.format(val_crop_height=args.val_crop_height,
                                val_crop_width=args.val_crop_width))
        sys.exit(1)

    num_classes = args.num_classes

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = model_list[args.model]
        model = model_name(num_classes=num_classes, pretrained=args.pretrained)

    if not torch.cuda.is_available():
        raise SystemError('GPU device not found')

    if args.device_ids:
        device_ids = list(map(int, args.device_ids.split(',')))
    else:
        # An empty --device-ids lets DataParallel pick every visible GPU.
        device_ids = None
    model = nn.DataParallel(model, device_ids=device_ids).cuda()

    # Validation uses one sample per GPU. With device_ids=None DataParallel
    # spans all visible devices, so count those instead of calling len(None).
    if device_ids is not None:
        valid_batch_size = len(device_ids)
    else:
        valid_batch_size = torch.cuda.device_count()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        # Hard-coded per-class counts (presumably pixel counts measured on
        # the training data -- TODO confirm); the weights are
        # total / (num_classes * count).
        if args.num_classes == 2:
            class_counts = [89371542, 7083233]
        else:
            class_counts = [89371542, 29703049, 7083233]
        labelweights = np.sum(class_counts) / \
            (np.multiply(num_classes, class_counts))

        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight,
                         class_weights=labelweights)

    cudnn.benchmark = True

    train_filename = os.path.join(args.filepath, 'trainval.txt')
    val_filename = os.path.join(args.filepath, 'test.txt')

    def train_transform(p=1):
        """Pad, random-crop, randomly flip and normalise a training sample."""
        return Compose([
            PadIfNeeded(min_height=args.train_crop_height,
                        min_width=args.train_crop_width, p=1),
            RandomCrop(height=args.train_crop_height,
                       width=args.train_crop_width, p=1),
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1)
        ], p=p)

    def val_transform(p=1):
        """Pad, centre-crop and normalise a validation sample."""
        return Compose([
            PadIfNeeded(min_height=args.val_crop_height,
                        min_width=args.val_crop_width, p=1),
            CenterCrop(height=args.val_crop_height,
                       width=args.val_crop_width, p=1),
            Normalize(p=1)
        ], p=p)

    train_loader = make_loader(train_filename,
                               shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size,
                               datatype=args.datatype)
    valid_loader = make_loader(val_filename,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size,
                               datatype=args.datatype)

    # Dump the parameters BEFORE args.root is replaced by a Path object,
    # which json.dumps cannot serialise.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))
    args.root = root

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(
        init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
        args=args,
        model=model,
        criterion=loss,
        train_loader=train_loader,
        valid_loader=valid_loader,
        validation=valid,
        num_classes=num_classes,
        model_name=args.model,
        dataset_type=dataset_type
    )
Esempio n. 16
0
class Trainer:
    """Train and validate a ``UNet`` on image/mask batches served by ``Loader``.

    The image paths matched by the ``image_paths`` glob are split into
    ``batches`` batches of ``batch_size``; the first ``train_size`` batches
    are used for training and the remaining ones for validation.
    """

    @classmethod
    def intersection_over_union(cls, y, z):
        """Return ``sum(min(y, z)) / sum(max(y, z))`` as a 0-dim tensor.

        For binary masks this is the usual intersection over union. The
        result is NaN when both inputs sum to zero (0 / 0).
        """
        return torch.sum(torch.min(y, z)) / torch.sum(torch.max(y, z))

    @classmethod
    def get_number_of_batches(cls, image_paths, batch_size):
        """Return how many batches are needed to cover ``image_paths``.

        The last batch may be partial, hence the ceiling division.
        """
        return int(math.ceil(len(image_paths) / batch_size))

    @classmethod
    def evaluate_loss(cls, criterion, output, target):
        """Return ``criterion(output, target) + 0.1 * (1 - IoU(output, target))``."""
        criterion_term = criterion(output, target)
        iou_term = 1 - Trainer.intersection_over_union(output, target)
        return criterion_term + 0.1 * iou_term

    @staticmethod
    def _batch_iou_scores(output, target):
        """Print and return the IoU of every sample in a batch as floats."""
        scores = []
        for i in range(output.shape[0]):
            binary_mask = Editor.make_binary_mask_from_torch(
                output[i, :, :, :], 1.0)
            iou = Trainer.intersection_over_union(
                binary_mask, target[i, :, :, :].cpu())
            scores.append(iou.item())
            print("IoU:", iou.item())
        return scores

    def __init__(self, side_length, batch_size, epochs, learning_rate,
                 momentum_parameter, seed, image_paths, state_dict,
                 train_val_split):
        """Store the hyper-parameters and derive the batch bookkeeping.

        Args:
            side_length: forwarded to ``Loader``.
            batch_size: number of samples per batch.
            epochs: number of passes over the training batches.
            learning_rate: learning rate handed to Adam.
            momentum_parameter: stored but not used by ``train_model``.
            seed: seed for NumPy / ``Loader.get_batch``; ``None`` disables
                seeding in ``set_seed``.
            image_paths: glob pattern expanded with ``glob.glob``.
            state_dict: file name (inside ``weights/``) the trained weights
                are saved under.
            train_val_split: fraction of the batches used for training.
        """
        self.side_length = side_length
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.momentum_parameter = momentum_parameter
        self.seed = seed
        self.image_paths = glob.glob(image_paths)
        self.batches = Trainer.get_number_of_batches(self.image_paths,
                                                     self.batch_size)
        self.model = UNet()
        self.loader = Loader(self.side_length)
        self.state_dict = state_dict
        self.train_val_split = train_val_split
        self.train_size = int(np.floor((self.train_val_split * self.batches)))

    def set_cuda(self):
        """Move the model to the GPU when one is available."""
        if torch.cuda.is_available():
            self.model = self.model.cuda()

    def set_seed(self):
        """Seed NumPy's global RNG when a seed was provided."""
        if self.seed is not None:
            np.random.seed(self.seed)

    def process_batch(self, batch):
        """Load batch number ``batch``, run the model and return ``(output, target)``.

        ``target`` is the mask channel clamped to ``[0, 1]`` so it can be fed
        to ``BCELoss``.
        """
        # Grab a batch, shuffled according to the provided seed. Note that
        # i-th image: samples[i][0], i-th mask: samples[i][1]
        samples = Loader.get_batch(self.image_paths, self.batch_size, batch,
                                   self.seed)
        # Cast samples into torch.FloatTensor for interaction with U-Net
        # (the former ``samples.astype(float)`` discarded its result).
        samples = torch.from_numpy(samples).float()

        # Cast into a CUDA tensor, if GPUs are available
        if torch.cuda.is_available():
            samples = samples.cuda()

        # Isolate images and their masks, adding the channel dimension
        # expected by U-Net.
        samples_images = samples[:, 0].unsqueeze(1)
        samples_masks = samples[:, 1].unsqueeze(1)

        # Run inputs through the model
        output = self.model(samples_images)

        # Clamp the target for proper interaction with BCELoss
        target = torch.clamp(samples_masks, min=0, max=1)

        del samples

        return output, target

    def train_model(self):
        """Run the train/validate loop and save the weights to ``weights/``.

        Each epoch trains on batches ``[0, train_size)`` and then validates
        on batches ``[train_size, batches)`` in eval mode without gradient
        tracking. Per-epoch average IoU is printed for both phases.
        """
        import os  # local import: only needed to create the output folder

        criterion = nn.BCELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)

        iteration = 0
        best_iteration = 0
        best_loss = 10**10

        losses_train = []
        losses_val = []
        average_iou_train = []
        average_iou_val = []

        print("BEGIN TRAINING")
        print("TRAINING BATCHES:", self.train_size)
        print("VALIDATION BATCHES:", self.batches - self.train_size)
        print("BATCH SIZE:", self.batch_size)
        print("EPOCHS:", self.epochs)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

        for k in range(self.epochs):
            current_epoch = k + 1
            print("EPOCH:", current_epoch)
            print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

            # Train
            self.model.train()
            # Reset per epoch so the average below describes this epoch only.
            iou_train = []
            for batch in range(self.train_size):
                iteration += 1
                output, target = self.process_batch(batch)
                loss = Trainer.evaluate_loss(criterion, output, target)
                print("EPOCH:", current_epoch)
                print("Batch", batch, "of", self.train_size)

                # Aggregate intersection over union scores for each element
                # in the batch
                iou_train.extend(Trainer._batch_iou_scores(output, target))

                # Clear data to prevent memory overload
                del target
                del output

                # Clear gradients, back-propagate, and update weights
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Record the loss value
                loss_value = loss.item()
                if best_loss > loss_value:
                    best_loss = loss_value
                    best_iteration = iteration
                losses_train.append(loss_value)

                if batch == self.train_size - 1:
                    print("LOSS:", loss_value)
                    print("~~~~~~~~~~~~~~~~~~~~~~~~~~")

            # Guard against an empty training split (train_val_split == 0).
            if iou_train:
                average_iou = sum(iou_train) / len(iou_train)
                print("Average IoU:", average_iou)
                average_iou_train.append(average_iou)

            # Validate: eval mode and no autograd graph, since nothing is
            # back-propagated here.
            self.model.eval()
            iou_val = []
            with torch.no_grad():
                for batch in range(self.train_size, self.batches):
                    output, target = self.process_batch(batch)
                    loss = Trainer.evaluate_loss(criterion, output, target)

                    iou_val.extend(Trainer._batch_iou_scores(output, target))

                    loss_value = loss.item()
                    losses_val.append(loss_value)
                    print("EPOCH:", current_epoch)
                    print("VALIDATION LOSS:", loss_value)
                    print("~~~~~~~~~~~~~~~~~~~~~~~~~~")
                    del output
                    del target

            # Guard against an empty validation split (train_val_split == 1).
            if iou_val:
                average_iou = sum(iou_val) / len(iou_val)
                print("Average IoU:", average_iou)
                average_iou_val.append(average_iou)

        print("Least loss", best_loss, "at iteration", best_iteration)

        # Make sure the output folder exists so the run is not lost at the
        # very last step.
        os.makedirs("weights", exist_ok=True)
        torch.save(self.model.state_dict(), "weights/" + self.state_dict)
Esempio n. 17
0
def _reduce_over_dim1(y_pred, reduction):
    """Collapse dimension 1 of ``y_pred`` as requested by ``reduction``.

    Supported values: ``'max'``, ``'mean'``, ``'top<k>'`` (mean of the ``k``
    largest values, capped at the dimension size) and ``'top<p>percent'``
    (mean of the top ``p`` percent, at least one value).

    Raises:
        ValueError: if ``reduction`` is none of the above.
    """
    if reduction == 'max':
        return y_pred.max(1)[0]
    if reduction == 'mean':
        return y_pred.mean(1)
    if reduction.startswith('top'):
        n = y_pred.size(1)
        if reduction.endswith('percent'):
            k = max(int(int(reduction[3:-7]) / 100. * n), 1)
        else:
            k = min(int(reduction[3:]), n)
        return y_pred.topk(k, dim=1)[0].mean(1)
    raise ValueError(f'reduction = {reduction} not handled!')


def semi_supervised_segmentation(mat,
                                 cor_map=None,
                                 model=None,
                                 out_channels=[8, 16, 32],
                                 kernel_size=3,
                                 frames_per_iter=100,
                                 num_iters=200,
                                 print_every=1,
                                 select_frames=False,
                                 return_model=False,
                                 optimizer_fn=torch.optim.AdamW,
                                 optimizer_fn_args={
                                     'lr': 1e-2,
                                     'weight_decay': 1e-3
                                 },
                                 loss_threshold=0,
                                 save_loss_folder=None,
                                 reduction='max',
                                 last_out_channels=None,
                                 verbose=False,
                                 device=torch.device('cuda')):
    """Semi-supervised semantic segmentation.

    A high-confidence mask is derived from ``cor_map`` (entries equal to -1
    are ignored); a two-class model is trained on slices of ``mat`` to
    reproduce the labelled entries, and the class-1 softmax probability
    averaged over several slices is returned.

    Args:
        mat: torch.Tensor. When ``model`` is None a 3-D input gets a leading
            dimension added; dimension 1 is the one sliced per iteration
            (presumably frames -- confirm against ``get_tensor_slice``).
        cor_map: optional precomputed map; computed with ``get_cor_map_4d``
            when None.
        model: optional model to train; a ``UNet`` is built when None.
        out_channels: forwarded to ``UNet``; its length sets the encoder
            depth. Not mutated here.
        kernel_size: int, or a tuple of three ints.
        frames_per_iter: slice size requested along dimension 1.
        num_iters: maximum number of training iterations.
        print_every: logging period; also the window/period of the
            early-stopping check.
        select_frames: forwarded to ``get_cor_map_4d``.
        return_model: also return the trained model.
        optimizer_fn: optimizer constructor.
        optimizer_fn_args: keyword arguments for ``optimizer_fn``. Not
            mutated here.
        loss_threshold: when > 0, stop once the mean loss over the last
            ``print_every`` iterations drops below it.
        save_loss_folder: when it names an existing folder, the loss history
            is saved there as ``loss__semi_supervised_segmentation.npy``.
        reduction: ``'max'``, ``'mean'``, ``'top<k>'`` or ``'top<p>percent'``.
        last_out_channels: forwarded to ``UNet``.
        verbose: print losses and plot the loss curve.
        device: device for the loss weights and the newly built model.

    Returns:
        soft_mask: 2D torch.Tensor
        model: learned model if return_model is True

    Raises:
        TypeError: if ``model`` is None and ``kernel_size`` is neither an int
            nor a tuple.
        ValueError: if ``reduction`` is not supported.
    """
    if cor_map is None:
        cor_map = get_cor_map_4d(mat, select_frames=select_frames)
    high_conf_mask = get_high_conf_mask(cor_map)
    # Each class is weighted by the count of the OTHER class, i.e. inversely
    # to its own frequency among the labelled entries.
    loss_fn = nn.CrossEntropyLoss(
        weight=torch.tensor([(high_conf_mask == 1).sum(), (
            high_conf_mask == 0).sum()]).float().to(device))
    loss_history = []
    if model is None:
        if mat.ndim == 3:
            mat = mat.unsqueeze(0)
        in_channels = mat.shape[0]
        num_classes = 2
        if isinstance(kernel_size, int):
            first_kernel = kernel_size
        elif isinstance(kernel_size, tuple):
            assert len(kernel_size) == 3
            first_kernel = kernel_size[0]
        else:
            # Previously fell through and failed later with an unbound
            # ``padding``; fail early with a clear message instead.
            raise TypeError(
                'kernel_size must be an int or a tuple of three ints, got '
                f'{type(kernel_size).__name__}')
        # NOTE(review): for a tuple kernel only the first entry determines
        # the padding, also for the row/column dimensions below -- confirm
        # this is intended for anisotropic kernels.
        padding = (first_kernel - 1) // 2
        encoder_depth = len(out_channels)
        nframe, nrow, ncol = mat.shape[-3:]
        assert nframe > 4 * encoder_depth * (first_kernel - 1)
        pool_kernel_size_row = get_prime_factors(nrow)[:encoder_depth]
        pool_kernel_size_col = get_prime_factors(ncol)[:encoder_depth]
        pool_sizes = [(1, pool_kernel_size_row[i], pool_kernel_size_col[i])
                      for i in range(encoder_depth)]
        # The decoder undoes the pooling in reverse order.
        unpool_sizes = pool_sizes[::-1]
        model = UNet(
            in_channels=in_channels,
            num_classes=num_classes,
            out_channels=out_channels,
            num_conv=2,
            n_dim=3,
            kernel_size=kernel_size,
            padding=[padding] + [(0, padding, padding)] * encoder_depth * 2,
            pool_kernel_size=pool_sizes,
            use_adaptive_pooling=True,
            same_shape=False,
            transpose_kernel_size=list(unpool_sizes),
            transpose_stride=list(unpool_sizes),
            padding_mode='zeros',
            normalization='layer_norm',
            last_out_channels=last_out_channels,
            activation=nn.LeakyReLU(negative_slope=0.01,
                                    inplace=True)).to(device)
    optimizer = optimizer_fn(
        filter(lambda p: p.requires_grad, model.parameters()),
        **optimizer_fn_args)

    # Train only on entries that carry a label (-1 marks "unknown").
    idx = torch.nonzero(high_conf_mask != -1, as_tuple=True)
    y_true = high_conf_mask[idx].long()
    for i in range(num_iters):
        x = get_tensor_slice(mat, dims=[1], sizes=[frames_per_iter])
        y_pred = _reduce_over_dim1(model(x), reduction)
        # (classes, n_labelled) -> (n_labelled, classes) for CrossEntropyLoss.
        y_pred = y_pred[:, idx[0], idx[1]].T
        loss = loss_fn(y_pred, y_true)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())
        if verbose and ((i + 1) % print_every == 0 or i == 0
                        or i == num_iters - 1):
            print(f'{i+1} loss={loss.item()}')
        if loss_threshold > 0 and (i + 1) % print_every == 0 and np.mean(
                loss_history[-print_every:]) < loss_threshold:
            break
    if verbose:
        plt.title('Training loss')
        plt.plot(loss_history, 'ro-', markersize=2)
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.show()
    if save_loss_folder is not None and os.path.exists(save_loss_folder):
        np.save(f'{save_loss_folder}/loss__semi_supervised_segmentation.npy',
                loss_history)
    with torch.no_grad():
        soft_masks = []
        # Separate name so the ``num_iters`` argument is not overwritten.
        num_eval_slices = mat.shape[1] // frames_per_iter + 1
        for _ in range(num_eval_slices):
            x = get_tensor_slice(mat, dims=[1], sizes=[frames_per_iter])
            y_pred = _reduce_over_dim1(model(x), reduction)
            soft_masks.append(torch.softmax(y_pred, dim=0)[1])
        soft_mask = torch.stack(soft_masks, dim=0).mean(0)
    if return_model:
        return soft_mask, model
    else:
        return soft_mask
Esempio n. 18
0
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    eval_loader = DataLoader(evalset,
                             batch_size=batch,
                             num_workers=1,
                             shuffle=True,
                             drop_last=True)

    model = UNet(n_channels=1, n_classes=1)
    model.to(device=device)

    # criterion = nn.CrossEntropyLoss()
    criterion = nn.MSELoss()
    # criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    training_loss = []
    eval_loss = []

    for epoch in range(num_epoch):
        train()
        evaluate()
        plot_losses()
        if (epoch % 10) == 0:
            torch.save(model.state_dict(),
                       os.path.join(models_path, f"unet_{attempt}_{epoch}.pt"))
        else:
            torch.save(model.state_dict(),
                       os.path.join(models_path, f"unet_{attempt}.pt"))
    print("Done!")
Esempio n. 19
0
def main():
    """Command-line entry point: build a segmentation model and train it.

    Parses the CLI options, creates the checkpoint root directory,
    instantiates the network selected by ``--model``, picks the loss and the
    validation routine from ``--type``, builds the train/validation loaders
    and hands everything to ``utils.train``. The parsed arguments are also
    written to ``params.json`` under ``--root``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=1, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=10)
    arg('--lr', type=float, default=0.0002)
    arg('--workers', type=int, default=10)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    # 'UNet16' must be listed here, otherwise argparse rejects it and the
    # UNet16 branch below can never be reached.
    arg('--model',
        type=str,
        default='DLinkNet',
        choices=['UNet', 'UNet11', 'UNet16', 'LinkNet34', 'DLinkNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Number of output channels depends on the segmentation problem type.
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained='vgg')
    elif args.model == 'LinkNet34':
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model == 'DLinkNet':
        model = D_LinkNet34(num_classes=num_classes, pretrained=True)
    else:
        # Defensive fallback; argparse `choices` normally prevents this.
        model = UNet(num_classes=num_classes, input_channels=3)

    # The model is wrapped in DataParallel and moved to GPU only when CUDA is
    # available; otherwise it stays on the CPU unwrapped.
    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        # loss = LossBinary(jaccard_weight=args.jaccard_weight)
        loss = LossBCE_DICE()
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    problem_type='binary'):
        """Wrap ``file_names`` in a RoboticsDataset and return its DataLoader."""
        return DataLoader(dataset=RoboticsDataset(file_names,
                                                  transform=transform,
                                                  problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    # train_file_names, val_file_names = get_split(args.fold)
    train_file_names, val_file_names = get_train_val_files()

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    # Flips are applied to training data only; both pipelines normalise the
    # image (and only the image).
    train_transform = DualCompose(
        [HorizontalFlip(),
         VerticalFlip(),
         ImageOnly(Normalize())])

    val_transform = DualCompose([ImageOnly(Normalize())])

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform,
                               problem_type=args.type)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform,
                               problem_type=args.type)

    # Persist the run configuration next to the checkpoints.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
Esempio n. 20
0
def train_unet(epoch=100):
    """Train the two-headed UNet on the eggs/pans dataset.

    Images under ``dataset/train/images/`` are shuffled once and split 90/10
    into training and validation sets. After every epoch the running
    validation IoU is compared with the best value so far; an improved model
    is saved under ``checkpoints/`` and the value is fed to a
    ``ReduceLROnPlateau`` scheduler running in ``'max'`` mode.

    Args:
        epoch: Number of epochs to train for.
    """
    checkpoint_dir = 'checkpoints/'
    # torch.save does not create missing directories; make sure the target
    # exists up front so the first improved epoch is not lost to an error.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Get all images in train set
    image_names = os.listdir('dataset/train/images/')
    image_names = [name for name in image_names if name.endswith(('.jpg', '.JPG', '.png'))]

    # Split into train and validation sets
    np.random.shuffle(image_names)
    split = int(len(image_names) * 0.9)
    train_image_names = image_names[:split]
    val_image_names = image_names[split:]

    # Create a dataset
    train_dataset = EggsPansDataset('dataset/train', train_image_names, mode='train')
    val_dataset = EggsPansDataset('dataset/train', val_image_names, mode='val')

    # Create a dataloader
    # NOTE(review): the training loader does not reshuffle between epochs
    # (shuffle=False); kept as in the original -- confirm this is intended.
    train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=False, num_workers=0)
    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0)

    # Initialize model and transfer to device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = UNet()
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=0.0001)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='max', verbose=True)
    loss_obj = EggsPansLoss()
    metrics_obj = EggsPansMetricIoU()

    # Keep best IoU and checkpoint
    best_iou = 0.0

    # Train epochs
    for epoch_idx in range(epoch):

        print('Epoch: {:2}/{}'.format(epoch_idx + 1, epoch))
        # Reset metrics and loss
        loss_obj.reset_loss()
        metrics_obj.reset_iou()

        # Train phase
        model.train()

        # Train epoch
        pbar = tqdm(train_dataloader)
        for imgs, egg_masks, pan_masks in pbar:

            # Convert to device
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            # Zero gradients
            optim.zero_grad()

            # Forward through net, and get the loss
            pred_egg_masks, pred_pan_masks = model(imgs)

            loss = loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
            # Called for its effect on the running IoU; the per-batch return
            # value is not needed here.
            metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            # Compute gradients and apply the parameter update
            loss.backward()
            optim.step()

            # Update metrics
            pbar.set_description('Loss: {:5.6f}, IoU: {:5.6f}'.format(loss_obj.get_running_loss(),
                                                                      metrics_obj.get_running_iou()))

        print('Validation: ')

        # Reset metrics and loss
        loss_obj.reset_loss()
        metrics_obj.reset_iou()

        # Val phase
        model.eval()

        # Val epoch
        pbar = tqdm(val_dataloader)
        for imgs, egg_masks, pan_masks in pbar:

            # Convert to device
            imgs = imgs.to(device)
            gt_egg_masks = egg_masks.to(device)
            gt_pan_masks = pan_masks.to(device)

            with torch.no_grad():
                # Forward through net; both calls below only update the
                # running loss / IoU accumulators.
                pred_egg_masks, pred_pan_masks = model(imgs)

                loss_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])
                metrics_obj([gt_egg_masks, gt_pan_masks], [pred_egg_masks, pred_pan_masks])

            pbar.set_description('Val Loss: {:5.6f}, IoU: {:5.6f}'.format(loss_obj.get_running_loss(),
                                                                          metrics_obj.get_running_iou()))

        # Running IoU over the whole validation pass of this epoch
        val_iou = metrics_obj.get_running_iou()

        # Save best model
        if best_iou < val_iou:
            best_iou = val_iou
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'epoch_{}_{:.4f}.pth'.format(
                epoch_idx + 1, val_iou)))

        # Reduce learning rate on plateau
        lr_scheduler.step(val_iou)

        print('\n')
        print('-'*100)
Esempio n. 21
0
                      shuffle=False,
                      num_workers=num_workers)

# Net model
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device(
    'cpu')
if args.MODEL == 'unet':
    # One input channel; one output channel per landmark.
    net = UNet(in_ch=1,
               out_ch=N_LANDMARKS,
               down_drop=args.DOWN_DROP,
               up_drop=args.UP_DROP)
else:
    # Fail with a clear message instead of an opaque NameError on `net`
    # a few lines further down.
    raise ValueError('Unsupported MODEL: {!r} (expected \'unet\')'.format(
        args.MODEL))
net.to(device)

# Optimizer + loss
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(),
                             lr=args.LEARN_RATE,
                             weight_decay=args.WEIGHT_DECAY)
# Scheduler uses ReduceLROnPlateau's default mode; the monitored value is
# supplied wherever scheduler.step(...) is called.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                 patience=args.OPTIM_PATIENCE,
                                                 verbose=True)


def train():
    """Run a training pass over ``train_dl`` using the module-level ``net``.

    NOTE(review): only the beginning of this function is present in the
    file; the loss computation and optimizer step are not visible here.
    """
    # Running totals for the pass (loss, MRE, SDR@4mm by their names --
    # presumably accumulated further down; confirm against the full source).
    train_loss, train_mre, train_sdr_4mm = 0, 0, 0
    train_examples = 0
    net.train()
    # Each batch yields three items; the third one is ignored here.
    for imgs, true_heatmaps, _ in train_dl:
        imgs = imgs.to(device)
        true_heatmaps = true_heatmaps.to(device)
Esempio n. 22
0
def train_val(config):
    """Train a segmentation model and validate it after every epoch.

    Builds the data loaders, model, optimizer, loss and a ``StepLR``
    scheduler from ``config``, then alternates one training pass with one
    validation pass per epoch. Training loss and validation metrics (Dice,
    accuracy, sensitivity, specificity, precision, F1) are logged to
    TensorBoard, and the whole model object is saved under
    ``config.result_path`` whenever the validation Dice improves.

    Args:
        config: Namespace-like object providing the ``*_img_dir`` /
            ``*_mask_dir`` paths, ``batch_size``, ``num_workers``, ``lr``,
            ``model_type``, ``data_type``, ``optimizer``, ``loss``,
            ``num_epochs``, ``num_train``, ``num_val`` and ``result_path``.

    Returns:
        None. Returns early, before any loader or writer is created, when
        ``config.model_type`` is not in the supported list.
    """
    supported_models = [
        'UNet', 'R2UNet', 'AUNet', 'R2AUNet', 'SEUNet', 'SEUNet++',
        'UNet++', 'DAUNet', 'DANet', 'AUNetR', 'RendDANet', "BASNet"
    ]
    # Validate first so that a bad model_type does not leave an open
    # SummaryWriter (and an empty run directory) behind.
    if config.model_type not in supported_models:
        print('ERROR!! model_type should be selected in supported models')
        print('Choose model %s' % config.model_type)
        return

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    # Supported names without a dedicated branch ('R2AUNet', 'SEUNet++',
    # 'DAUNet') fall through to the plain UNet in the final else.
    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "AUNet":
        model = AUNet()
    elif config.model_type == "R2UNet":
        model = R2UNet()
    elif config.model_type == "SEUNet":
        model = SEUNet(useCSE=False, useSSE=False, useCSSE=True)
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "DANet":
        model = DANet(backbone='resnet101', nclass=1)
    elif config.model_type == "AUNetR":
        model = AUNet_R16(n_classes=1, learned_bilinear=True)
    elif config.model_type == "RendDANet":
        model = RendDANet(backbone='resnet101', nclass=1)
    elif config.model_type == "BASNet":
        model = BASNet(n_channels=3, n_classes=1)
    else:
        model = UNet()

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device, dtype=torch.float)

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-6,
                        momentum=0.9)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    if config.loss == "dice":
        criterion = DiceLoss()
    elif config.loss == "bce":
        criterion = nn.BCELoss()
    elif config.loss == "bas":
        criterion = BasLoss()
    else:
        criterion = MixLoss()

    # Learning rate is multiplied by 0.1 every 40 epochs.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)
    global_step = 0
    best_dice = 0.0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img') as train_pbar:
            model.train()
            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                # The forward pass is unpacked into eight outputs and the
                # criterion is called with all of them plus the mask.
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                loss = criterion(d0, d1, d2, d3, d4, d5, d6, d7, mask)
                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1

                # if global_step % 100 == 0:
                #     writer.add_images('masks/true', mask, global_step)
                #     writer.add_images('masks/pred', d0 > 0.5, global_step)
            scheduler.step()
        epoch_dice = 0.0
        epoch_acc = 0.0
        epoch_sen = 0.0
        epoch_spe = 0.0
        epoch_pre = 0.0
        current_num = 0
        # Validation needs no gradients: no_grad avoids building autograd
        # graphs for every forward pass, which only wasted memory before.
        with torch.no_grad(), tqdm(total=config.num_val,
                                   desc="Epoch %d / %d validation round" %
                                   (epoch + 1, config.num_epochs),
                                   unit='img') as val_pbar:
            model.eval()
            locker = 0
            for image, mask in val_loader:
                current_num += image.shape[0]
                image = image.to(device, dtype=torch.float)
                mask = mask.to(device, dtype=torch.float)
                # Only the first output (d0) is scored during validation.
                d0, d1, d2, d3, d4, d5, d6, d7 = model(image)
                batch_dice = dice_coeff(mask, d0).item()
                # Metrics are weighted by batch size so the final division
                # by current_num yields a per-image average.
                epoch_dice += batch_dice * image.shape[0]
                epoch_acc += get_accuracy(pred=d0, true=mask) * image.shape[0]
                epoch_sen += get_sensitivity(pred=d0,
                                             true=mask) * image.shape[0]
                epoch_spe += get_specificity(pred=d0,
                                             true=mask) * image.shape[0]
                epoch_pre += get_precision(pred=d0, true=mask) * image.shape[0]
                # Log example masks only for the batch with index 200.
                if locker == 200:
                    writer.add_images('masks/true', mask, epoch + 1)
                    writer.add_images('masks/pred', d0 > 0.5, epoch + 1)
                val_pbar.set_postfix(**{'dice (batch)': batch_dice})
                val_pbar.update(image.shape[0])
                locker += 1
            epoch_dice /= float(current_num)
            epoch_acc /= float(current_num)
            epoch_sen /= float(current_num)
            epoch_spe /= float(current_num)
            epoch_pre /= float(current_num)
            epoch_f1 = get_F1(SE=epoch_sen, PR=epoch_pre)
            if epoch_dice > best_dice:
                best_dice = epoch_dice
                writer.add_scalar('Best Dice/test', best_dice, epoch + 1)
                # Saves the full model object (not just the state_dict).
                torch.save(
                    model, config.result_path + "/%s_%s_%d.pth" %
                    (config.model_type, str(epoch_dice), epoch + 1))
            logging.info('Validation Dice Coeff: {}'.format(epoch_dice))
            print("epoch dice: " + str(epoch_dice))
            writer.add_scalar('Dice/test', epoch_dice, epoch + 1)
            writer.add_scalar('Acc/test', epoch_acc, epoch + 1)
            writer.add_scalar('Sen/test', epoch_sen, epoch + 1)
            writer.add_scalar('Spe/test', epoch_spe, epoch + 1)
            writer.add_scalar('Pre/test', epoch_pre, epoch + 1)
            writer.add_scalar('F1/test', epoch_f1, epoch + 1)

    writer.close()
    print("Training finished")
Esempio n. 23
0
def train(device, gen_model, disc_model, real_dataset_path, epochs):
    """Train a Wasserstein GAN with gradient penalty.

    The generator is a ``UNet(1, 3)`` fed with single-channel Gaussian noise
    of size 224x224; the critic is ``discriminator_model(3, 1024)``. Existing
    checkpoints are loaded when present, and both networks are saved after
    every outer iteration.

    Args:
        device: Torch device the networks and batches are moved to.
        gen_model: Path of the generator checkpoint (loaded if it exists,
            overwritten after every iteration).
        disc_model: Path of the critic checkpoint (same handling).
        real_dataset_path: Location passed to ``ColorDataset`` for the real
            images.
        epochs: Number of outer iterations. Each one performs ``n_critic``
            critic updates followed by a single generator update, so this is
            a count of generator steps rather than passes over the dataset.
    """
    train_transform = tv.transforms.Compose([
        tv.transforms.Resize((224, 224)),
        tv.transforms.RandomHorizontalFlip(0.5),
        tv.transforms.RandomVerticalFlip(0.5),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, ), (0.5, ))
    ])

    realdataset = ColorDataset(real_dataset_path, transform=train_transform)
    realloader = torch.utils.data.DataLoader(realdataset,
                                             batch_size=20,
                                             shuffle=True,
                                             num_workers=cpu_count(),
                                             drop_last=True)
    realiter = iter(realloader)

    discriminator = discriminator_model(3, 1024).to(device)
    disc_optimizer = torch.optim.Adam(discriminator.parameters(),
                                      lr=0.0001,
                                      betas=(0, 0.9))
    if os.path.exists(disc_model):
        # map_location lets a checkpoint written on another device (e.g. a
        # GPU) be restored on the device used for this run.
        discriminator.load_state_dict(
            torch.load(disc_model, map_location=device))

    generator = UNet(1, 3).to(device)
    gen_optimizer = torch.optim.Adam(generator.parameters(),
                                     lr=0.0001,
                                     betas=(0, 0.9))
    if os.path.exists(gen_model):
        generator.load_state_dict(torch.load(gen_model, map_location=device))

    # Scalar +1 / -1 used as the upstream gradient in backward(): passing
    # `mone` maximises the quantity, passing `one` minimises it.
    one = torch.FloatTensor([1])
    mone = one * -1
    one = one.to(device).squeeze()
    mone = mone.to(device).squeeze()

    n_critic = 5
    lam = 10
    for _ in tqdm.trange(epochs, desc="Epochs"):
        for param in discriminator.parameters():
            param.requires_grad = True

        for _ in range(n_critic):
            real_data, realiter = try_iter(realiter, realloader)
            real_data = real_data.to(device)

            disc_optimizer.zero_grad()

            disc_real = discriminator(real_data)
            real_cost = torch.mean(disc_real)
            real_cost.backward(mone)

            # fake_data, fakeiter = try_iter(fakeiter, fakeloader)
            noise = torch.randn(real_data.shape[0], 1, 224, 224)
            noise = noise.to(device)
            # The critic update must not backpropagate into the generator,
            # so the fake batch is produced without tracking gradients.
            with torch.no_grad():
                fake_data = generator(noise)
            disc_fake = discriminator(fake_data)
            fake_cost = torch.mean(disc_fake)
            fake_cost.backward(one)

            # The penalty is evaluated between real and *generated* images;
            # previously the raw 1-channel noise was passed here instead.
            gradient_penalty = calc_gp(device, discriminator, real_data,
                                       fake_data, lam)
            gradient_penalty.backward()

            disc_optimizer.step()
        # Freeze the critic while the generator is updated.
        for param in discriminator.parameters():
            param.requires_grad = False
        gen_optimizer.zero_grad()

        # fake_data, fakeiter = try_iter(fakeiter, fakeloader)
        noise = torch.randn(real_data.shape[0], 1, 224, 224)
        noise = noise.to(device)
        disc_g = discriminator(generator(noise)).mean()
        disc_g.backward(mone)
        gen_optimizer.step()

        torch.save(generator.state_dict(), gen_model)
        torch.save(discriminator.state_dict(), disc_model)
Esempio n. 24
0
def main():
    """Command-line entry point: configure and train a segmentation model.

    Parses the CLI options, validates the crop sizes, builds the model
    selected by ``--model`` (wrapped in ``DataParallel`` on GPU), chooses the
    loss and validation routine from ``--type``, creates the loaders with
    their augmentation pipelines and calls ``utils.train``. The parsed
    arguments are written to ``params.json`` under ``--root``.

    Raises:
        SystemError: If no CUDA device is available.
        SystemExit: With status 1 if a crop size fails
            ``utils.check_crop_size``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.5, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--train_crop_height', type=int, default=1024)
    arg('--train_crop_width', type=int, default=1280)
    arg('--val_crop_height', type=int, default=1024)
    arg('--val_crop_width', type=int, default=1280)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    arg('--model', type=str, default='UNet', choices=moddel_list.keys())

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Invalid crop sizes are a usage error, so exit with a non-zero status
    # (the previous exit code 0 signalled success to calling scripts).
    if not utils.check_crop_size(args.train_crop_height,
                                 args.train_crop_width):
        print('Input image sizes should be divisible by 32, but train '
              'crop sizes ({train_crop_height} and {train_crop_width}) '
              'are not.'.format(train_crop_height=args.train_crop_height,
                                train_crop_width=args.train_crop_width))
        sys.exit(1)

    if not utils.check_crop_size(args.val_crop_height, args.val_crop_width):
        print('Input image sizes should be divisible by 32, but validation '
              'crop sizes ({val_crop_height} and {val_crop_width}) '
              'are not.'.format(val_crop_height=args.val_crop_height,
                                val_crop_width=args.val_crop_width))
        sys.exit(1)

    # Number of output channels depends on the segmentation problem type.
    if args.type == 'parts':
        num_classes = 4
    elif args.type == 'instruments':
        num_classes = 8
    else:
        num_classes = 1

    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    else:
        model_name = moddel_list[args.model]
        model = model_name(num_classes=num_classes, pretrained=True)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()
    else:
        raise SystemError('GPU device not found')

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)

    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    problem_type='binary',
                    batch_size=1):
        """Wrap ``file_names`` in a RoboticsDataset and return its DataLoader."""
        return DataLoader(dataset=RoboticsDataset(file_names,
                                                  transform=transform,
                                                  problem_type=problem_type),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    def train_transform(p=1):
        """Training pipeline: pad, random crop, random flips, normalise."""
        return Compose([
            PadIfNeeded(min_height=args.train_crop_height,
                        min_width=args.train_crop_width,
                        p=1),
            RandomCrop(height=args.train_crop_height,
                       width=args.train_crop_width,
                       p=1),
            VerticalFlip(p=0.5),
            HorizontalFlip(p=0.5),
            Normalize(p=1)
        ],
                       p=p)

    def val_transform(p=1):
        """Validation pipeline: pad, centre crop, normalise (no flips)."""
        return Compose([
            PadIfNeeded(min_height=args.val_crop_height,
                        min_width=args.val_crop_width,
                        p=1),
            CenterCrop(
                height=args.val_crop_height, width=args.val_crop_width, p=1),
            Normalize(p=1)
        ],
                       p=p)

    # Validation uses one image per GPU. With an empty --device-ids,
    # device_ids is None (DataParallel then uses every visible GPU), so fall
    # back to the device count instead of calling len(None).
    if device_ids:
        valid_batch_size = len(device_ids)
    else:
        valid_batch_size = torch.cuda.device_count()

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform(p=1),
                               problem_type=args.type,
                               batch_size=args.batch_size)
    valid_loader = make_loader(val_file_names,
                               transform=val_transform(p=1),
                               problem_type=args.type,
                               batch_size=valid_batch_size)

    # Persist the run configuration next to the checkpoints.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    if args.type == 'binary':
        valid = validation_binary
    else:
        valid = validation_multi

    print(model.parameters())
    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
Esempio n. 25
0
def train_val(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    train_loader = get_dataloader(img_dir=config.train_img_dir,
                                  mask_dir=config.train_mask_dir,
                                  mode="train",
                                  batch_size=config.batch_size,
                                  num_workers=config.num_workers,
                                  smooth=config.smooth)
    val_loader = get_dataloader(img_dir=config.val_img_dir,
                                mask_dir=config.val_mask_dir,
                                mode="val",
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)

    writer = SummaryWriter(
        comment="LR_%f_BS_%d_MODEL_%s_DATA_%s" %
        (config.lr, config.batch_size, config.model_type, config.data_type))

    if config.model_type == "UNet":
        model = UNet()
    elif config.model_type == "UNet++":
        model = UNetPP()
    elif config.model_type == "SEDANet":
        model = SEDANet()
    elif config.model_type == "RefineNet":
        model = rf101()
    elif config.model_type == "DANet":
        # src = "./pretrained/60_DANet_0.8086.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # pretrained_dict.pop('seg1.1.weight')
        # pretrained_dict.pop('seg1.1.bias')
        model = DANet(backbone='resnext101',
                      nclass=config.output_ch,
                      pretrained=True,
                      norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "Deeplabv3+":
        # src = "./pretrained/Deeplabv3+.pth"
        # pretrained_dict = torch.load(src, map_location='cpu').module.state_dict()
        # print("load pretrained params from stage 1: " + src)
        # # print(pretrained_dict.keys())
        # for key in list(pretrained_dict.keys()):
        #     if key.split('.')[0] == "cbr_last":
        #         pretrained_dict.pop(key)
        model = deeplabv3_plus.DeepLabv3_plus(in_channels=3,
                                              num_classes=config.output_ch,
                                              backend='resnet101',
                                              os=16,
                                              pretrained=True,
                                              norm_layer=nn.BatchNorm2d)
        # model_dict = model.state_dict()
        # model_dict.update(pretrained_dict)
        # model.load_state_dict(model_dict)
    elif config.model_type == "HRNet_OCR":
        model = seg_hrnet_ocr.get_seg_model()
    elif config.model_type == "scSEUNet":
        model = scSEUNet(pretrained=True, norm_layer=nn.BatchNorm2d)
    else:
        model = UNet()

    if config.iscontinue:
        model = torch.load("./exp/13_Deeplabv3+_0.7619.pth",
                           map_location='cpu').module

    for k, m in model.named_modules():
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatability

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(device)

    labels = [1, 2, 3, 4, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
    objects = [
        '水体', '道路', '建筑物', '机场', '停车场', '操场', '普通耕地', '农业大棚', '自然草地', '绿地绿化',
        '自然林', '人工林', '自然裸土', '人为裸土', '其它'
    ]
    frequency = np.array([
        0.0279, 0.0797, 0.1241, 0.00001, 0.0616, 0.0029, 0.2298, 0.0107,
        0.1207, 0.0249, 0.1470, 0.0777, 0.0617, 0.0118, 0.0187
    ])

    if config.optimizer == "sgd":
        optimizer = SGD(model.parameters(),
                        lr=config.lr,
                        weight_decay=1e-4,
                        momentum=0.9)
    elif config.optimizer == "adamw":
        optimizer = adamw.AdamW(model.parameters(), lr=config.lr)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # weight = torch.tensor([1, 1.5, 1, 2, 1.5, 2, 2, 1.2]).to(device)
    # criterion = nn.CrossEntropyLoss(weight=weight)

    if config.smooth == "all":
        criterion = LabelSmoothSoftmaxCE()
    elif config.smooth == "edge":
        criterion = LabelSmoothCE()
    else:
        criterion = nn.CrossEntropyLoss()

    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[25, 30, 35, 40], gamma=0.5)
    # scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=5, verbose=True)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                         T_0=15,
                                                         eta_min=1e-4)

    global_step = 0
    max_fwiou = 0
    for epoch in range(config.num_epochs):
        epoch_loss = 0.0
        seed = np.random.randint(0, 2, 1)
        seed = 0
        print("seed is ", seed)
        if seed == 1:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size // 2,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size // 2,
                                        num_workers=config.num_workers)
        else:
            train_loader = get_dataloader(img_dir=config.train_img_dir,
                                          mask_dir=config.train_mask_dir,
                                          mode="train",
                                          batch_size=config.batch_size,
                                          num_workers=config.num_workers,
                                          smooth=config.smooth)
            val_loader = get_dataloader(img_dir=config.val_img_dir,
                                        mask_dir=config.val_mask_dir,
                                        mode="val",
                                        batch_size=config.batch_size,
                                        num_workers=config.num_workers)

        cm = np.zeros([15, 15])
        print(optimizer.param_groups[0]['lr'])
        with tqdm(total=config.num_train,
                  desc="Epoch %d / %d" % (epoch + 1, config.num_epochs),
                  unit='img',
                  ncols=100) as train_pbar:
            model.train()

            for image, mask in train_loader:
                image = image.to(device, dtype=torch.float32)

                if seed == 0:
                    pass
                elif seed == 1:
                    image = F.interpolate(image,
                                          size=(384, 384),
                                          mode='bilinear',
                                          align_corners=True)
                    mask = F.interpolate(mask.float(),
                                         size=(384, 384),
                                         mode='nearest')

                if config.smooth == "edge":
                    mask = mask.to(device, dtype=torch.float32)
                else:
                    mask = mask.to(device, dtype=torch.long).argmax(dim=1)

                aux_out, out = model(image)
                aux_loss = criterion(aux_out, mask)
                seg_loss = criterion(out, mask)
                loss = aux_loss + seg_loss

                # pred = model(image)
                # loss = criterion(pred, mask)

                epoch_loss += loss.item()

                writer.add_scalar('Loss/train', loss.item(), global_step)
                train_pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_pbar.update(image.shape[0])
                global_step += 1
                # if global_step > 10:
                #     break

            # scheduler.step()
            print("\ntraining epoch loss: " +
                  str(epoch_loss / (float(config.num_train) /
                                    (float(config.batch_size)))))
        torch.cuda.empty_cache()
        val_loss = 0
        with torch.no_grad():
            with tqdm(total=config.num_val,
                      desc="Epoch %d / %d validation round" %
                      (epoch + 1, config.num_epochs),
                      unit='img',
                      ncols=100) as val_pbar:
                model.eval()
                locker = 0
                for image, mask in val_loader:
                    image = image.to(device, dtype=torch.float32)
                    target = mask.to(device, dtype=torch.long).argmax(dim=1)
                    mask = mask.cpu().numpy()
                    _, pred = model(image)
                    val_loss += F.cross_entropy(pred, target).item()
                    pred = pred.cpu().detach().numpy()
                    mask = semantic_to_mask(mask, labels)
                    pred = semantic_to_mask(pred, labels)
                    cm += get_confusion_matrix(mask, pred, labels)
                    val_pbar.update(image.shape[0])
                    if locker == 5:
                        writer.add_images('mask_a/true',
                                          mask[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_a/pred',
                                          pred[2, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/true',
                                          mask[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                        writer.add_images('mask_b/pred',
                                          pred[3, :, :],
                                          epoch + 1,
                                          dataformats='HW')
                    locker += 1

                    # break
                miou = get_miou(cm)
                fw_miou = (miou * frequency).sum()
                scheduler.step()

                if True:
                    if torch.__version__ == "1.6.0":
                        torch.save(model,
                                   config.result_path + "/%d_%s_%.4f.pth" %
                                   (epoch + 1, config.model_type, fw_miou),
                                   _use_new_zipfile_serialization=False)
                    else:
                        torch.save(
                            model, config.result_path + "/%d_%s_%.4f.pth" %
                            (epoch + 1, config.model_type, fw_miou))
                    max_fwiou = fw_miou
                print("\n")
                print(miou)
                print("testing epoch loss: " + str(val_loss),
                      "FWmIoU = %.4f" % fw_miou)
                writer.add_scalar('FWIoU/val', fw_miou, epoch + 1)
                writer.add_scalar('loss/val', val_loss, epoch + 1)
                for idx, name in enumerate(objects):
                    writer.add_scalar('iou/val' + name, miou[idx], epoch + 1)
            torch.cuda.empty_cache()
    writer.close()
    print("Training finished")
Esempio n. 26
0
def train():
    """Train the single-channel, single-class UNet and report metrics.

    Everything is configured through the module-level ``opt`` object.
    The function optionally resumes from ``opt.load_model_path``, steps a
    ``MultiStepLR`` schedule once per epoch, writes a checkpoint after every
    epoch, plots the mean ``dice_loss`` over the validation set each epoch
    and the test-set recall every tenth epoch.
    """
    t.cuda.set_device(1)

    # n_channels: medical images are single-channel grayscale.
    # n_classes: binary segmentation.
    net = UNet(n_channels=1, n_classes=1)
    optimizer = t.optim.SGD(net.parameters(),
                            lr=opt.learning_rate,
                            momentum=0.9,
                            weight_decay=0.0005)
    # Binary cross-entropy (original author: suited to masks that cover a
    # large part of the image).
    criterion = t.nn.BCELoss()

    start_epoch = 0
    if opt.load_model_path:
        checkpoint = t.load(opt.load_model_path)

        # Checkpoints are saved from `net` below, which is only wrapped in
        # DataParallel when opt.use_gpu is set.  Strip the "module." prefix
        # only from keys that actually carry it, so checkpoints written by a
        # non-parallel run load correctly too.
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in checkpoint['net'].items():
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        net.load_state_dict(new_state_dict)  # restore model weights
        optimizer.load_state_dict(checkpoint['optimizer'])  # restore optimizer
        start_epoch = checkpoint['epoch']  # restore epoch counter
        # NOTE(review): the optimizer state is restored before the model is
        # moved to the GPU below — confirm the momentum buffers end up on the
        # same device as the parameters.

    # The learning rate is multiplied by gamma at every milestone epoch.
    if start_epoch == 0:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=-1)  # default
        print('从头训练 ,学习率为{}'.format(optimizer.param_groups[0]['lr']))
    else:
        scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=opt.milestones,
                                                     gamma=0.1,
                                                     last_epoch=start_epoch)
        print('加载预训练模型{}并从{}轮开始训练,学习率为{}'.format(
            opt.load_model_path, start_epoch, optimizer.param_groups[0]['lr']))

    # Move the network to the GPU(s).
    if opt.use_gpu:
        net = t.nn.DataParallel(net, device_ids=opt.device_ids)
        net.cuda()
        cudnn.benchmark = True

    # Visualisation helper.
    vis = Visualizer(opt.env)

    train_data = NodeDataSet(train=True)
    val_data = NodeDataSet(val=True)
    test_data = NodeDataSet(test=True)

    # Data loaders.
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=True,
                                num_workers=opt.num_workers)
    test_dataloader = DataLoader(test_data,
                                 opt.test_batch_size,
                                 shuffle=False,
                                 num_workers=opt.num_workers)
    for epoch in range(opt.max_epoch - start_epoch):
        print('开始 epoch {}/{}.'.format(start_epoch + epoch + 1, opt.max_epoch))
        epoch_loss = 0

        # Let the scheduler decide whether the learning rate changes.
        # NOTE(review): this runs before optimizer.step(); recent PyTorch
        # releases expect the opposite order — confirm the targeted version.
        scheduler.step()

        for ii, (img, mask) in enumerate(train_dataloader):
            # Bind the target unconditionally so that CPU runs do not fail
            # with a NameError further down.
            true_masks = mask
            if opt.use_gpu:
                img = img.cuda()
                true_masks = true_masks.cuda()
            masks_pred = net(img)

            # BCELoss expects probabilities, so apply the sigmoid first.
            masks_probs = t.sigmoid(masks_pred)

            # Loss = binary cross-entropy (+ optional Dice loss).
            loss = criterion(masks_probs.view(-1), true_masks.view(-1))
            if opt.use_dice_loss:
                loss += dice_loss(masks_probs, true_masks)

            epoch_loss += loss.item()

            if ii % 2 == 0:
                vis.plot('训练集loss', loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log the state at the end of the epoch (loss of the last batch).
        vis.log("epoch:{epoch},lr:{lr},loss:{loss}".format(
            epoch=epoch, loss=loss.item(), lr=optimizer.param_groups[0]['lr']))

        # `ii` is the zero-based index of the last batch, so the number of
        # batches is ii + 1 (dividing by `ii` skews the mean and divides by
        # zero when the loader yields a single batch).
        vis.plot('每轮epoch的loss均值', epoch_loss / (ii + 1))

        # Save model, optimizer and epoch counter.
        # NOTE(review): `epoch` is the loop index of this run (it restarts at
        # 0 after resuming) and also names the file — confirm this is the
        # intended bookkeeping for resumed runs.
        state = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        t.save(state, opt.checkpoint_root + '{}_unet.pth'.format(epoch))

        # ============ validation ===================
        net.eval()
        # Inference only: no autograd graph is needed here.
        with t.no_grad():
            tot = 0
            for jj, (img_val, mask_val) in enumerate(val_dataloader):
                true_mask_val = mask_val
                if opt.use_gpu:
                    img_val = img_val.cuda()
                    true_mask_val = true_mask_val.cuda()

                mask_pred = net(img_val)
                mask_pred = (t.sigmoid(mask_pred) > 0.5).float()  # threshold 0.5
                tot += dice_loss(mask_pred, true_mask_val).item()
            val_dice = tot / (jj + 1)  # mean over all validation batches
            vis.plot('验证集 Dice损失', val_dice)

            # ============ test-set recall ===================
            # Evaluated once every 10 epochs.
            if epoch % 10 == 0:
                result_test = []
                for kk, (img_test, mask_test) in enumerate(test_dataloader):
                    # Only the predictions are collected here; the
                    # ground-truth mask is not used.
                    if opt.use_gpu:
                        img_test = img_test.cuda()
                    # Original author notes the output as [1, 1, 512, 512].
                    mask_pred_test = net(img_test)

                    probs = t.sigmoid(mask_pred_test).squeeze().cpu().detach(
                    ).numpy()
                    pred_mask = probs > opt.out_threshold
                    result_test.append(pred_mask)

                # Recall is computed from all predicted test masks.
                vis.plot('测试集二维召回率', getRecall(result_test).getResult())
        net.train()
Esempio n. 27
0
def main():
    """Parse the command line, build the chosen model and start training.

    Builds the segmentation network selected by ``--model``, wraps it in
    ``nn.DataParallel`` when CUDA is available, creates the training and
    validation loaders for the requested fold, writes the run parameters to
    ``<root>/params.json`` and hands everything to ``utils.train``.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', default=0.3, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=1)
    arg('--limit', type=int, default=10000, help='number of images in epoch')
    arg('--n-epochs', type=int, default=100)
    arg('--lr', type=float, default=0.0001)
    arg('--workers', type=int, default=12)
    arg('--model',
        type=str,
        default='UNet',
        choices=['UNet', 'UNet11', 'UNet16', 'AlbuNet34'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    num_classes = 1
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        # Not reachable through the parser's current `choices`.
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    elif args.model in ('AlbuNet', 'AlbuNet34'):
        # The parser offers 'AlbuNet34'; matching only 'AlbuNet' sent that
        # choice to the fallback branch and built a plain UNet instead.
        model = AlbuNet34(num_classes=num_classes, pretrained=True)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    loss = LossBinary(jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names, shuffle=False, transform=None, limit=None):
        """Build a DataLoader over ``file_names`` with the CLI batch settings."""
        return DataLoader(dataset=AngyodysplasiaDataset(file_names,
                                                        transform=transform,
                                                        limit=limit),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    train_file_names, val_file_names = get_split(args.fold)

    print('num train = {}, num_val = {}'.format(len(train_file_names),
                                                len(val_file_names)))

    # Training adds random flips, rotation and colour jitter on top of the
    # padding/cropping/normalisation shared with validation.
    train_transform = DualCompose([
        SquarePaddingTraining(),
        CenterCrop([574, 574]),
        HorizontalFlip(),
        VerticalFlip(),
        Rotate(),
        ImageOnly(RandomHueSaturationValue()),
        ImageOnly(Normalize())
    ])

    val_transform = DualCompose([
        SquarePaddingTraining(),
        CenterCrop([574, 574]),
        ImageOnly(Normalize())
    ])

    train_loader = make_loader(train_file_names,
                               shuffle=True,
                               transform=train_transform,
                               limit=args.limit)
    valid_loader = make_loader(val_file_names, transform=val_transform)

    # Record the exact run configuration next to the checkpoints.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=validation_binary,
                fold=args.fold)
Esempio n. 28
0
def start():
    """Command-line entry point for training or testing the multiclass UNet.

    With ``--train`` the model is trained on a 90/10 random split of the
    dataset found under ``--data``, the loss curve and weights are saved,
    and the images under ``./pdf_data/`` are run through ``test_only``.
    Without ``--train`` but with ``--load`` the given weights are loaded
    and ``test_only`` is run on the data folder.
    """
    import os  # local import: only needed to create the output folders

    parser = argparse.ArgumentParser(
        description='UNet + BDCLSTM for BraTS Dataset')
    parser.add_argument('--batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for training (default: 4)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=4,
                        metavar='N',
                        help='input batch size for testing (default: 4)')
    parser.add_argument('--train',
                        action='store_true',
                        default=False,
                        help='Argument to train model (default: False)')
    parser.add_argument('--epochs',
                        type=int,
                        default=2,
                        metavar='N',
                        help='number of epochs to train (default: 2)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='learning rate (default: 0.001)')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training (default: False)')
    parser.add_argument('--log-interval',
                        type=int,
                        default=1,
                        metavar='N',
                        help='batches to wait before logging training status')
    parser.add_argument('--size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='imsize')
    parser.add_argument('--load',
                        type=str,
                        default=None,
                        metavar='str',
                        help='weight file to load (default: None)')
    parser.add_argument('--data',
                        type=str,
                        default='./Data/',
                        metavar='str',
                        help='folder that contains data')
    parser.add_argument('--save',
                        type=str,
                        default='OutMasks',
                        metavar='str',
                        help='Identifier to save npy arrays with')
    parser.add_argument('--modality',
                        type=str,
                        default='flair',
                        metavar='str',
                        help='Modality to use for training (default: flair)')
    parser.add_argument('--optimizer',
                        type=str,
                        default='SGD',
                        metavar='str',
                        help='Optimizer (default: SGD)')
    # The ADAM branch below reads these two values; without the arguments
    # `--optimizer ADAM` failed with an AttributeError.
    parser.add_argument('--beta1',
                        type=float,
                        default=0.9,
                        metavar='B1',
                        help='Adam beta1 (default: 0.9)')
    parser.add_argument('--beta2',
                        type=float,
                        default=0.999,
                        metavar='B2',
                        help='Adam beta2 (default: 0.999)')

    args = parser.parse_args()
    args.cuda = args.cuda and torch.cuda.is_available()

    DATA_FOLDER = args.data

    # Multiclass model (a binary variant would use num_classes=2).
    model = UNet(num_channels=1, num_classes=3)

    if args.cuda:
        model.cuda()

    if args.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.99)
    elif args.optimizer == 'ADAM':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
    else:
        # Only training needs an optimizer; the error is reported there so
        # that test-only runs keep working with any --optimizer value.
        optimizer = None

    # Loss function.
    criterion = DICELossMultiClass()

    if args.train:
        if optimizer is None:
            parser.error("unsupported optimizer '{}' (use SGD or ADAM)".format(
                args.optimizer))

        # Load the dataset and split it 90/10 into train/validation.
        full_dataset = BraTSDatasetUnet(DATA_FOLDER,
                                        im_size=[args.size, args.size],
                                        transform=tr.ToTensor())

        train_size = int(0.9 * len(full_dataset))
        test_size = len(full_dataset) - train_size
        train_dataset, validation_dataset = torch.utils.data.random_split(
            full_dataset, [train_size, test_size])

        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=1)
        validation_loader = DataLoader(validation_dataset,
                                       batch_size=args.test_batch_size,
                                       shuffle=False,
                                       num_workers=1)

        print("Training Data : ", len(train_loader.dataset))
        print("Validaion Data : ", len(validation_loader.dataset))

        loss_list = []
        t_start = timer()
        for i in tqdm(range(args.epochs)):
            train(model, i, loss_list, train_loader, optimizer, criterion,
                  args)
            test(model, validation_loader, criterion, args, validation=True)
        t_end = timer()
        print("Training completed in {:0.2f}s".format(t_end - t_start))

        # Make sure the output folders exist so a finished training run is
        # not lost to a missing-directory error while saving.
        os.makedirs('./plots', exist_ok=True)
        os.makedirs('./npy-files/loss-files', exist_ok=True)

        plt.plot(loss_list)
        plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size,
                                                    args.epochs, args.lr))
        plt.xlabel("Number of iterations")
        plt.ylabel("Average DICE loss per batch")
        plt.savefig("./plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(
            args.save, args.batch_size, args.epochs, args.lr))

        np.save(
            './npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(
                args.save, args.batch_size, args.epochs, args.lr),
            np.asarray(loss_list))
        print("Testing Validation")
        test(model, validation_loader, criterion, args, save_output=True)
        torch.save(
            model.state_dict(),
            'unet-multiclass-model-{}-{}-{}'.format(args.batch_size,
                                                    args.epochs, args.lr))

        print("Testing PDF images")
        test_dataset = TestDataset('./pdf_data/',
                                   im_size=[args.size, args.size],
                                   transform=tr.ToTensor())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=1)
        print("Test Data : ", len(test_loader.dataset))
        test_only(model, test_loader, criterion, args)

    elif args.load is not None:
        test_dataset = TestDataset(DATA_FOLDER,
                                   im_size=[args.size, args.size],
                                   transform=tr.ToTensor())
        test_loader = DataLoader(test_dataset,
                                 batch_size=args.test_batch_size,
                                 shuffle=False,
                                 num_workers=1)
        print("Test Data : ", len(test_loader.dataset))
        # Map the weights to the CPU when CUDA is not in use so that
        # checkpoints written on a GPU machine can still be loaded.
        model.load_state_dict(
            torch.load(args.load,
                       map_location=None if args.cuda else 'cpu'))
        test_only(model, test_loader, criterion, args)
def main():
    """Train the VHR segmentation model with help of a pre-trained HR model.

    Parses the command line, builds the VHR network, collects the labelled
    and unlabelled VHR file lists, creates the data loaders, loads the HR
    model weights, runs ``utilsTrain_seq.train_model``, saves the resulting
    weights under ``logs_<out-file>/mapping/`` and finally calls
    ``find_metrics`` on the train/val/test file lists.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--device-ids', type=str, default='0', help='For example 0,1 to run on two GPUs')
    arg('--fold-out', type=int, help='fold train test', default=0)
    arg('--fold-in', type=int, help='fold train val', default=0)
    arg('--percent', type=float, help='percent of data', default=1)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=4)
    arg('--limit', type=int, default=10000, help='number of images in epoch')
    arg('--n-epochs', type=int, default=40)
    arg('--n-steps', type=int, default=200)
    arg('--lr', type=float, default=0.003)
    arg('--modelVHR', type=str, default='UNet11', choices=['UNet11','UNet','AlbuNet34','SegNet'])
    arg('--dataset-path-HR', type=str, default='data_HR', help='ain path  of the HR dataset')
    arg('--model-path-HR', type=str, default='logs_HR/mapping/model_40epoch_HR_UNet11.pth', help='path of the model of HR')
    arg('--dataset-path-VHR', type=str, default='data_VHR', help='ain path  of the VHR dataset')
    arg('--name-file-HR', type=str, default='_HR', help='name file of HR dataset')
    arg('--dataset-file', type=str, default='VHR', help='main dataset resolution,depend of this correspond a specific crop' )
    arg('--out-file', type=str, default='seq', help='the file in which save the outputs')
    arg('--train-val-file-HR', type=str, default='train_val_HR', help='name of the train-val file' )
    arg('--test-file-HR', type=str, default='test_HR', help='name of the test file' )
    arg('--train-val-file-VHR', type=str, default='train_val_850', help='name of the train-val file' )
    arg('--test-file-VHR', type=str, default='test_850', help='name of the test file' )

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    num_classes = 1
    input_channels = 4

    if args.modelVHR == 'UNet11':
        model_VHR = UNet11(num_classes=num_classes, input_channels=input_channels)
    elif args.modelVHR == 'UNet':
        model_VHR = UNet(num_classes=num_classes, input_channels=input_channels)
    elif args.modelVHR == 'AlbuNet34':
        model_VHR = AlbuNet34(num_classes=num_classes, num_input_channels=input_channels, pretrained=False)
    elif args.modelVHR == 'SegNet':
        model_VHR = SegNet(num_classes=num_classes, num_input_channels=input_channels, pretrained=False)
    else:
        model_VHR = UNet11(num_classes=num_classes, input_channels=4)

    if torch.cuda.is_available():
        if args.device_ids:
            device_ids = list(map(int, args.device_ids.split(',')))
        else:
            device_ids = None
        model_VHR = nn.DataParallel(model_VHR, device_ids=device_ids).cuda()

    cudnn.benchmark = True

    # All file lists and the final weights are written here; create the
    # folder up front so np.save / torch.save do not fail on a fresh run.
    out_path = Path(('logs_{}/mapping/').format(args.out_file))
    out_path.mkdir(parents=True, exist_ok=True)

    # ---------------- labelled VHR data ----------------
    data_path_VHR = Path(args.dataset_path_VHR)
    print("data_path:",data_path_VHR)

    name_file_VHR = '_' + str(int(args.percent*100)) + '_percent_' + args.out_file

    train_val_file_names = np.array(sorted(glob.glob(str((data_path_VHR/args.train_val_file_VHR/'images')) + "/*.npy")))
    test_file_names_VHR = np.array(sorted(glob.glob(str((data_path_VHR/args.test_file_VHR/'images')) + "/*.npy")))

    # Optionally keep only a fraction of the labelled training data.
    if args.percent != 1:
        _, train_val_file_names = percent_split(train_val_file_names, args.percent)

    train_file_VHR_lab, val_file_VHR_lab = get_split_in(train_val_file_names, args.fold_in)
    np.save(str(os.path.join(out_path, "train_files{}_{}_fold{}_{}.npy".format(name_file_VHR, args.modelVHR, args.fold_out, args.fold_in))), train_file_VHR_lab)
    np.save(str(os.path.join(out_path, "val_files{}_{}_fold{}_{}.npy".format(name_file_VHR, args.modelVHR, args.fold_out, args.fold_in))), val_file_VHR_lab)

    # ---------------- unlabelled VHR data ----------------
    train_path_VHR_unlab = data_path_VHR/'unlabel'/'train'/'images'
    val_path_VHR_unlab = data_path_VHR/'unlabel'/'val'/'images'

    train_file_VHR_unlab = np.array(sorted(list(train_path_VHR_unlab.glob('*.npy'))))
    val_file_VHR_unlab = np.array(sorted(list(val_path_VHR_unlab.glob('*.npy'))))

    print('num train_lab = {}, num_val_lab = {}'.format(len(train_file_VHR_lab), len(val_file_VHR_lab)))
    print('num train_unlab = {}, num_val_unlab = {}'.format(len(train_file_VHR_unlab), len(val_file_VHR_unlab)))

    # Only the mean/std are used below; the first returned value is ignored.
    _, mean_values_VHR, std_values_VHR = meanstd(train_file_VHR_lab, val_file_VHR_lab, test_file_names_VHR, str(data_path_VHR), input_channels)

    def make_loader(file_names, shuffle=False, transform=None, mode='train', batch_size=4, limit=None):
        """Wrap ``file_names`` in a WaterDataset and return its DataLoader."""
        return DataLoader(
            dataset=WaterDataset(file_names, transform=transform, mode=mode, limit=limit),
            shuffle=shuffle,
            batch_size=batch_size,
            pin_memory=torch.cuda.is_available()
        )

    # ---------------- transformations ----------------
    train_transform_VHR = DualCompose([
            CenterCrop(512),
            HorizontalFlip(),
            VerticalFlip(),
            Rotate(),
            ImageOnly(Normalize(mean=mean_values_VHR, std=std_values_VHR))
        ])

    val_transform_VHR = DualCompose([
            CenterCrop(512),
            ImageOnly(Normalize(mean=mean_values_VHR, std=std_values_VHR))
        ])

    mean_values_HR = (0.11952524, 0.1264638 , 0.13479991, 0.15017026)
    std_values_HR = (0.08844988, 0.07304429, 0.06740904, 0.11003125)

    # NOTE(review): the two *_unlab transforms (HR statistics) are built but
    # never passed to a loader; the unlabelled loaders below use the VHR
    # transforms.  Confirm which normalisation is intended.
    train_transform_VHR_unlab = DualCompose([
            CenterCrop(512),
            HorizontalFlip(),
            VerticalFlip(),
            Rotate(),
            ImageOnly(Normalize(mean=mean_values_HR, std=std_values_HR))
        ])

    val_transform_VHR_unlab = DualCompose([
            CenterCrop(512),
            ImageOnly(Normalize(mean=mean_values_HR, std=std_values_HR))
        ])

    # ---------------- data loaders ----------------
    train_loader_VHR_lab = make_loader(train_file_VHR_lab, shuffle=True, transform=train_transform_VHR, batch_size=2, mode="train")
    valid_loader_VHR_lab = make_loader(val_file_VHR_lab, transform=val_transform_VHR, batch_size=4, mode="train")

    train_loader_VHR_unlab = make_loader(train_file_VHR_unlab, shuffle=True, transform=train_transform_VHR, batch_size=4, mode="unlb_train")
    valid_loader_VHR_unlab = make_loader(val_file_VHR_unlab, transform=val_transform_VHR, batch_size=2, mode="unlb_val")

    dataloaders_VHR_lab = {
        'train': train_loader_VHR_lab, 'val': valid_loader_VHR_lab
    }

    dataloaders_VHR_unlab = {
        'train': train_loader_VHR_unlab, 'val': valid_loader_VHR_unlab
    }

    # Record the exact run configuration.
    root.joinpath(('params_{}.json').format(args.out_file)).write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    # All parameters of the VHR model are optimised.
    optimizer_ft = optim.Adam(model_VHR.parameters(), lr=args.lr)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=20, gamma=0.1)

    # ---------------- pre-trained HR model ----------------
    PATH_HR = args.model_path_HR

    model_HR = UNet11(num_classes=num_classes)
    model_HR.cuda()
    model_HR.load_state_dict(torch.load(PATH_HR))

    model_VHR = utilsTrain_seq.train_model(
        out_file=args.out_file,
        name_file_VHR=name_file_VHR,
        model_HR=model_HR,
        model_VHR=model_VHR,
        optimizer=optimizer_ft,
        scheduler=exp_lr_scheduler,
        dataloaders_VHR_lab=dataloaders_VHR_lab,
        dataloaders_VHR_unlab=dataloaders_VHR_unlab,
        fold_out=args.fold_out,
        fold_in=args.fold_in,
        name_model_VHR=args.modelVHR,
        n_steps=args.n_steps,
        num_epochs=args.n_epochs
        )

    # The template has five placeholders; the original call passed an extra
    # leading n_epochs, which shifted every field (model name under
    # "foldout", fold_out under "foldin", fold_in before "epochs").
    checkpoint_name = 'model{}_{}_foldout{}_foldin{}_{}epochs.pth'.format(
        name_file_VHR, args.modelVHR, args.fold_out, args.fold_in, args.n_epochs)
    # Unwrap DataParallel when present; fall back to the model itself.
    bare_model = getattr(model_VHR, 'module', model_VHR)
    torch.save(bare_model.state_dict(), str(out_path) + '/' + checkpoint_name)

    print(args.modelVHR)
    max_values_all_VHR = 3521

    find_metrics(train_file_names=train_file_VHR_lab,
                 val_file_names=val_file_VHR_lab,
                 test_file_names=test_file_names_VHR,
                 max_values=max_values_all_VHR,
                 mean_values=mean_values_VHR,
                 std_values=std_values_VHR,
                 model=model_VHR,
                 fold_out=args.fold_out,
                 fold_in=args.fold_in,
                 name_model=args.modelVHR,
                 epochs=args.n_epochs,
                 out_file=args.out_file,
                 dataset_file=args.dataset_file,
                 name_file=name_file_VHR)
Esempio n. 30
0
class TimeTransfer(pl.LightningModule):
    def __init__(self, hparams):
        """Build the network, the training loss and the dataset split.

        Args:
            hparams: Hyper-parameter namespace. The constructor reads
                ``hidden_dim``, ``data_dir``, ``data_split`` and ``gpus``;
                other methods of this class additionally read ``lr``,
                ``batch_size`` and ``n_samples``.
        """
        super().__init__()
        self.hparams = hparams

        # Network called as ``self.unet(x, t)`` in ``forward``.
        self.unet = UNet(3, hparams.hidden_dim)

        # Dataset root with a leading "~" expanded.
        self.data_dir = os.path.expanduser(hparams.data_dir)

        # Entries 0..2 are used by the *_dataloader methods as index
        # boundaries of the train / val / test subsets.
        # NOTE(review): presumably cumulative sums of data_split -- confirm
        # against prefix_sum.
        self.split_indices = prefix_sum(hparams.data_split)

        # The loss module lives on the GPU whenever at least one is requested.
        use_gpu = hparams.gpus > 0
        self.device = torch.device('cuda' if use_gpu else 'cpu')
        self.criteria = PerceptualLoss().to(self.device)

    def forward(self, x, t):
        """Run the U-Net on ``x`` with the same time value for every sample.

        Args:
            x: Input batch; its first dimension is taken as the batch size.
            t: Time value that is broadcast to one entry per sample.

        Returns:
            The output of ``self.unet`` for ``x`` and the per-sample times.
        """
        # One "1.0" per sample, moved to the device the input lives on.
        ones = torch.ones(x.shape[0]).to(x.device)
        return self.unet(x, t * ones)

    def get_time_batch(self, batch, t):
        """Return the entry of ``batch`` stored under index ``t``."""
        return batch[t]

    def training_step(self, batch, batch_nb):
        """Compute the training loss for one batch (required hook).

        A source hour and a target hour are drawn independently. The network
        receives the source-hour entry of the batch together with the target
        hour, and its output is scored against the target-hour entry with
        ``self.criteria``.

        Args:
            batch: Container indexed by hour through ``get_time_batch``.
            batch_nb: Unused here.

        Returns:
            Dict holding the loss under ``'loss'`` and, under ``'log'``, a
            dict exposing the same value as ``'train_loss'``.
        """
        # torch.randint excludes its upper bound, so hours come from 0..22.
        # NOTE(review): if a batch holds 24 hourly slots, hour 23 is never
        # sampled -- confirm whether the bound should be 24.
        source_hour, target_hour = [
            torch.randint(0, 23, (1, )).item() for _ in range(2)
        ]
        inputs = self.get_time_batch(batch, source_hour)
        targets = self.get_time_batch(batch, target_hour)
        loss = self.criteria(self.forward(inputs, target_hour), targets)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def validation_step(self, batch, batch_nb):
        """Compute the validation MSE for one batch (optional hook).

        Source and target hours are drawn independently, exactly as in
        training, but the prediction is scored with ``F.mse_loss`` instead of
        ``self.criteria``.

        Args:
            batch: Container indexed by hour through ``get_time_batch``.
            batch_nb: Unused here.

        Returns:
            Dict with the batch loss under ``'val_loss'``.
        """
        # torch.randint excludes its upper bound, so hours come from 0..22.
        # NOTE(review): if a batch holds 24 hourly slots, hour 23 is never
        # sampled -- confirm whether the bound should be 24.
        source_hour, target_hour = [
            torch.randint(0, 23, (1, )).item() for _ in range(2)
        ]
        inputs = self.get_time_batch(batch, source_hour)
        targets = self.get_time_batch(batch, target_hour)
        prediction = self.forward(inputs, target_hour)
        return {'val_loss': F.mse_loss(prediction, targets)}

    def validation_end(self, outputs):
        """Average the per-batch validation losses (optional hook).

        Args:
            outputs: Iterable of dicts, each carrying a ``'val_loss'`` tensor.

        Returns:
            Dict with the mean loss under ``'avg_val_loss'`` and, under
            ``'log'``, a dict exposing it as ``'val_loss'``.
        """
        per_batch = [step_out['val_loss'] for step_out in outputs]
        mean_loss = torch.stack(per_batch).mean()
        return {'avg_val_loss': mean_loss, 'log': {'val_loss': mean_loss}}

    def test_step(self, batch, batch_nb):
        """Compute the test MSE for one batch (optional hook).

        Unlike training and validation, the target hour equals the source
        hour, so the network output is compared with the same batch entry it
        was given as input.

        Args:
            batch: Container indexed by hour through ``get_time_batch``.
            batch_nb: Unused here.

        Returns:
            Dict with the batch loss under ``'test_loss'``.
        """
        # torch.randint excludes its upper bound, so the hour is in 0..22.
        # NOTE(review): if a batch holds 24 hourly slots, hour 23 is never
        # sampled -- confirm whether the bound should be 24.
        hour = torch.randint(0, 23, (1, )).item()
        inputs = self.get_time_batch(batch, hour)
        targets = self.get_time_batch(batch, hour)
        prediction = self.forward(inputs, hour)
        return {'test_loss': F.mse_loss(prediction, targets)}

    def test_end(self, outputs):
        """Average the per-batch test losses (optional hook).

        Args:
            outputs: Iterable of dicts, each carrying a ``'test_loss'`` tensor.

        Returns:
            Dict with the mean loss under ``'avg_test_loss'``; the same
            ``{'test_loss': mean}`` dict object is exposed under both
            ``'log'`` and ``'progress_bar'``.
        """
        per_batch = [step_out['test_loss'] for step_out in outputs]
        mean_loss = torch.stack(per_batch).mean()
        summary = {'test_loss': mean_loss}
        return dict(avg_test_loss=mean_loss, log=summary, progress_bar=summary)

    def on_epoch_end(self):
        """Log a grid of sample predictions to the experiment logger.

        Takes the first ``hparams.n_samples`` items of the test dataset,
        draws a source and a target hour, runs the network on the source-hour
        images and logs rows of (input, prediction, target) under the tag
        ``'samples'`` at the current epoch.
        """
        dataset = self.test_dataloader()[0].dataset
        samples = dataset[:self.hparams.n_samples]

        # torch.randint excludes its upper bound, so hours come from 0..22.
        # NOTE(review): if a sample holds 24 hourly slots, hour 23 is never
        # drawn -- confirm whether the bound should be 24.
        source_hour = torch.randint(0, 23, (1, )).item()
        target_hour = torch.randint(0, 23, (1, )).item()
        x = self.get_time_batch(samples, source_hour).to(self.device)
        y = self.get_time_batch(samples, target_hour).to(self.device)

        # The prediction is only visualised, so do not record an autograd
        # graph for it.
        with torch.no_grad():
            y_hat = self.forward(x, target_hour)

        # Stacking on dim=1 interleaves the three tensors per sample; the
        # view then flattens them into one image list. The image dimensions
        # are taken from ``x`` rather than hard-coded (3, 450, 800).
        triples = torch.stack([x, y_hat, y], dim=1).view(-1, *x.shape[-3:])
        grid = torchvision.utils.make_grid(triples, nrow=3)
        self.logger.experiment.add_image('samples', grid, self.current_epoch)

    def configure_optimizers(self):
        """Create the Adam optimizer and its step-decay schedule (required hook).

        Only the parameters of ``self.unet`` are optimized, starting from
        ``hparams.lr``.

        Returns:
            Tuple ``([optimizer], [scheduler])``: one optimizer and one
            learning-rate scheduler, each wrapped in a list.
        """
        optimizer = torch.optim.Adam(self.unet.parameters(),
                                     lr=self.hparams.lr)
        # The previous LambdaLR returned 0.5 only when the epoch was 3, 7 or
        # 12. LambdaLR multiplies the *initial* lr by that value, so the rate
        # was halved for those single epochs and then jumped back to the base
        # value. MultiStepLR applies the 0.5 factor cumulatively at the same
        # milestones.
        # NOTE(review): assumes step decay was the intent -- confirm.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[3, 7, 12], gamma=0.5, last_epoch=-1)
        return [optimizer], [scheduler]

    @pl.data_loader
    def train_dataloader(self):
        """Build the shuffled training loader (required hook).

        Returns:
            ``DataLoader`` over the items of ``TimedImageDataset`` with
            indices ``0 .. split_indices[0] - 1``.
        """
        full_dataset = TimedImageDataset(self.data_dir)
        train_indices = range(self.split_indices[0])
        return DataLoader(
            Subset(full_dataset, train_indices),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=3,
        )

    @pl.data_loader
    def val_dataloader(self):
        """Build the shuffled validation loader (optional hook).

        Returns:
            ``DataLoader`` over the items of ``TimedImageDataset`` with
            indices ``split_indices[0] .. split_indices[1] - 1``.
        """
        full_dataset = TimedImageDataset(self.data_dir)
        first, stop = self.split_indices[0], self.split_indices[1]
        return DataLoader(
            Subset(full_dataset, range(first, stop)),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=3,
        )

    @pl.data_loader
    def test_dataloader(self):
        """Build the shuffled test loader (optional hook).

        Returns:
            ``DataLoader`` over the items of ``TimedImageDataset`` with
            indices ``split_indices[1] .. split_indices[2] - 1``.
        """
        full_dataset = TimedImageDataset(self.data_dir)
        first, stop = self.split_indices[1], self.split_indices[2]
        return DataLoader(
            Subset(full_dataset, range(first, stop)),
            batch_size=self.hparams.batch_size,
            shuffle=True,
            num_workers=3,
        )