Example No. 1
class UNet_crf(nn.Module):
    """U-Net-style encoder/decoder with a trainable CRF layer on the output
    logits (the helper blocks double_conv, down_step, up_step and out_conv
    are defined elsewhere in the original repo)."""

    def __init__(self, n_channels, n_classes):
        super().__init__()

        # encoder
        self.down1 = double_conv(n_channels, 32, 32)
        self.down2 = down_step(32, 64)

        self.bottom_bridge = down_step(64, 128)

        # decoder
        self.up1 = up_step(128, 64)
        self.up2 = up_step(64, 32)

        self.outconv = out_conv(32, n_classes)

        # trainable CRF post-processing layer (from crfseg)
        self.crf = CRF(n_spatial_dims=3)
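
The constructor alone already shows the main point: CRF(n_spatial_dims=3) from crfseg is an ordinary differentiable nn.Module, so it can sit at the end of a network like any other layer. A minimal standalone sketch of its input/output contract (the shapes below are made up for illustration):

import torch
from crfseg import CRF

crf = CRF(n_spatial_dims=3)
logits = torch.randn(1, 2, 8, 32, 32)   # (batch, n_classes, D, H, W)
smoothed = crf(logits)                  # mean-field smoothing; shape is preserved
assert smoothed.shape == logits.shape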
Example No. 2
# imports assumed from context (mmsegmentation + crfseg); data_path, repo_path,
# config_file and checkpoint_file are defined earlier in the original script
from mmseg.apis import init_segmentor
from mmseg.datasets import CityscapesDataset, build_dataloader
from crfseg import CRF

video_name = "sample1"
height, width = 720, 1280  # used to build the resized paths below
video_path = str(data_path / "origin" / f"{video_name}.mp4")
resized_video_path = str(data_path / f"{video_name}_{height}x{width}.mp4")
resized_frames_path = data_path / f"{video_name}_{height}x{width}"
resized_frames_path.mkdir(parents=True, exist_ok=True)

frames_path = resized_frames_path / "img_dir"

cityscapes_path = repo_path.parent / "data" / "cityscapes"
device = "cuda"
batch_size = 2

seg_model = init_segmentor(config_file, checkpoint_file, device=device)
crf = CRF(n_spatial_dims=2, returns="log-proba").to(device)

cfg = seg_model.cfg
train_dataset = CityscapesDataset(data_root=cityscapes_path,
                                  pipeline=cfg.data.train.pipeline,
                                  img_dir=cfg.data.train.img_dir,
                                  ann_dir=cfg.data.train.ann_dir,
                                  test_mode=False)

val_dataset = CityscapesDataset(data_root=cityscapes_path,
                                pipeline=cfg.data.val.pipeline,
                                img_dir=cfg.data.val.img_dir,
                                ann_dir=cfg.data.val.ann_dir,
                                test_mode=False)

train_loader = build_dataloader(train_dataset,
                                samples_per_gpu=batch_size,
                                workers_per_gpu=2,
                                dist=False,
                                shuffle=True)
# (the original example is truncated mid-call; the arguments above are assumed)
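
Note the returns="log-proba" flag on the CRF above: instead of raw logits the layer outputs log-probabilities, which pair directly with nn.NLLLoss. A small illustration of that pairing (the batch size and the 19 Cityscapes classes are just example values):

import torch
import torch.nn as nn
from crfseg import CRF

crf = CRF(n_spatial_dims=2, returns="log-proba")
logits = torch.randn(2, 19, 64, 128)    # (batch, n_classes, H, W)
log_proba = crf(logits)                 # log-probabilities, same shape
target = torch.randint(0, 19, (2, 64, 128))
loss = nn.NLLLoss()(log_proba, target)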
Example No. 3
# the 2D backbone: an FPN built from residual blocks ("layers" and friends are
# the dpipe-style helpers used by this repo); the name unet_2d is taken from
# its use below
unet_2d = nn.Sequential(
    nn.Conv2d(dataset.n_modalities, 8, kernel_size=3, padding=1),
    layers.FPN(layers.ResBlock2d,
               downsample=nn.MaxPool2d(2, ceil_mode=True),
               upsample=nn.Identity,
               merge=lambda left, down: torch.cat(
                   layers.interpolate_to_left(left, down, 'bilinear'), dim=1),
               structure=[[[8, 8, 8], [16, 8, 8]], [[8, 16, 16], [32, 16, 8]],
                          [[16, 32, 32], [64, 32, 16]],
                          [[32, 64, 64], [128, 64, 32]],
                          [[64, 128, 128], [256, 128, 64]],
                          [[128, 256, 256], [512, 256, 128]], [256, 512, 256]],
               kernel_size=3,
               padding=1),
    layers.PreActivation2d(8, 1, kernel_size=1))

model = CRFWrapper(Quasi3DWrapper(unet_2d),
                   CRF(n_spatial_dims=3, filter_size=5)).to(CONFIG['device'])

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
with_crf = Switch(False, epoch_to_value={100: True})  # enable the CRF head only after epoch 100


# predict: unpack_args lets predict take its (image, spacing) pair as a single
# argument, and add_extract_dims adds/strips singleton batch and channel dims
# around inference_step (both are dpipe-style helpers)
@unpack_args
@add_extract_dims(1, 2)
def predict(image, spacing):
    return inference_step(image,
                          spacing,
                          architecture=model,
                          activation=torch.sigmoid)
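
Quasi3DWrapper and CRFWrapper come from the experiment code, not from crfseg itself; the idea is to run the 2D network slice by slice and let a 3D CRF smooth the stacked logits across slices. A rough standalone sketch of that pattern (the class below is a hypothetical stand-in, not the real wrapper's API):

import torch
import torch.nn as nn
from crfseg import CRF

class SliceWise2D(nn.Module):  # hypothetical stand-in for Quasi3DWrapper
    def __init__(self, net2d):
        super().__init__()
        self.net2d = net2d

    def forward(self, x):  # x: (batch, C, D, H, W)
        b, c, d, h, w = x.shape
        flat = x.permute(0, 2, 1, 3, 4).reshape(b * d, c, h, w)
        out = self.net2d(flat)  # apply the 2D net to every slice
        return out.reshape(b, d, -1, h, w).permute(0, 2, 1, 3, 4)

net = nn.Sequential(SliceWise2D(nn.Conv2d(1, 2, 3, padding=1)),
                    CRF(n_spatial_dims=3, filter_size=5))
print(net(torch.randn(1, 1, 8, 32, 32)).shape)  # torch.Size([1, 2, 8, 32, 32])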
Example No. 4
def train(train_dir, val_dir, checkpoint_file=None):
    train_loader, val_loader = load_data(train_dir, val_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.segmentation.deeplabv3_resnet50(pretrained=False,
                                                   num_classes=2)

    for param in model.parameters():
        param.requires_grad = True

    # torchvision's DeepLabV3 returns a dict; unwrap its 'out' logits before
    # they reach the CRF layer (small helper, added here to make the forward
    # pass runnable)
    class DeepLabLogits(nn.Module):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def forward(self, x):
            return self.net(x)['out']

    model = nn.Sequential(DeepLabLogits(model), CRF(n_spatial_dims=2))

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.002)
    model.to(device)

    writer = SummaryWriter('runs/deeplab_experiment_11')

    epochs = 100
    print_freq = 500
    save_freq = 25

    # "poly" learning-rate decay, stepped once per training iteration
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (len(train_loader) * epochs))**0.9)

    if checkpoint_file is not None:
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        last_epoch = checkpoint['epoch']
        acc_global = checkpoint['acc_global']
        acc = checkpoint['acc']
        iou = checkpoint['iou']
        miou = checkpoint['miou']
        dice_coef = checkpoint['dice_coef']
        mcc = checkpoint['mcc']
    else:
        last_epoch = 0

    for epoch in range(last_epoch + 1, last_epoch + epochs + 1):
        train_mcc = 0.0
        train_confmat = utils.ConfusionMatrix(num_classes=2)

        for inputs, labels in train_loader:
            model.train()
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            try:
                logps = model(inputs)
            except RuntimeError:
                # skip batches that fail at runtime (e.g. CUDA OOM)
                continue

            loss = criterion(logps, labels)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

            train_pred = torch.argmax(logps, dim=1)
            train_confmat.update(labels.flatten().long(), train_pred.flatten())
            train_mcc += matthews_corrcoef(labels.cpu().numpy().flatten(),
                                           train_pred.cpu().numpy().flatten())

        train_acc_global, train_acc, train_iou, train_miou = train_confmat.compute()
        train_mcc = train_mcc / len(train_loader)
        print("Train loss: ", loss.item(), " ... ", "Train acc: ",
              train_acc_global, " ... ", "Train mIOU: ", train_miou, " ... ",
              "Train MCC: ", train_mcc)

        writer.add_scalar('Train/Loss', loss.item(), epoch)
        writer.add_scalar('Train/Global Accuracy', train_acc_global, epoch)
        writer.add_scalar('Train/Accuracy/nontumor', train_acc[0], epoch)
        writer.add_scalar('Train/Accuracy/tumor', train_acc[1], epoch)
        writer.add_scalar('Train/IoU/nontumor', train_iou[0], epoch)
        writer.add_scalar('Train/IoU/tumor', train_iou[1], epoch)
        writer.add_scalar('Train/mIoU', train_miou, epoch)
        writer.add_scalar('Train/MCC', train_mcc, epoch)

        # Evaluate validation loss and confusion matrix after every epoch
        model.eval()
        val_loss = 0.0
        val_mcc = 0.0
        confmat = utils.ConfusionMatrix(num_classes=2)

        with torch.no_grad():
            for val_inputs, val_labels in val_loader:
                val_inputs, val_labels = val_inputs.to(device), val_labels.to(
                    device)

                try:
                    val_logps = model(val_inputs)
                except RuntimeError:
                    continue

                val_preds = torch.argmax(val_logps, dim=1)
                confmat.update(val_labels.flatten().long(),
                               val_preds.flatten())
                val_mcc += matthews_corrcoef(
                    val_labels.cpu().numpy().flatten(),
                    val_preds.cpu().numpy().flatten())
                batch_loss = criterion(val_logps, val_labels)
                val_loss += batch_loss.item()

        val_loss = val_loss / len(val_loader)
        acc_global, acc, iou, miou = confmat.compute()
        # note: Dice here is computed on the last validation batch only
        dice_coef = dice(val_preds.flatten(), val_labels.flatten())
        val_mcc = val_mcc / len(val_loader)
        print("Val loss: ", val_loss, " ... ", "Val acc: ", acc_global,
              " ... ", "Val mIOU: ", miou, " ... ", "Val Dice coeff: ",
              dice_coef, "Val MCC: ", val_mcc)

        writer.add_scalar('Val/Loss', val_loss, epoch)
        writer.add_scalar('Val/Global Accuracy', acc_global, epoch)
        writer.add_scalar('Val/Accuracy/nontumor', acc[0], epoch)
        writer.add_scalar('Val/Accuracy/tumor', acc[1], epoch)
        writer.add_scalar('Val/IoU/nontumor', iou[0], epoch)
        writer.add_scalar('Val/IoU/tumor', iou[1], epoch)
        writer.add_scalar('Val/mIoU', miou, epoch)
        writer.add_scalar('Val/Dice coeff', dice_coef, epoch)
        writer.add_scalar('Val/MCC', val_mcc, epoch)
        writer.flush()  # flush (not close) so later epochs keep logging to the same event file

        # Save checkpoint after every save_freq epochs
        if epoch % save_freq == 0:
            utils.save_on_master(
                {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'acc_global': acc_global,
                    'acc': acc,
                    'iou': iou,
                    'miou': miou,
                    'dice_coef': dice_coef,
                    'mcc': val_mcc
                },
                os.path.join(
                    './checkpoints',
                    'deeplab_resnet101_experiment_11_{}.pth'.format(epoch)))
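
The poly schedule above decays per iteration: after half of all iterations the learning rate has dropped to 0.5**0.9 ≈ 0.54 of its initial value. A tiny standalone check of that behavior (the loader length and epoch count are made-up stand-ins):

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.002)
total_iters = 500 * 100  # stand-in for len(train_loader) * epochs
sched = torch.optim.lr_scheduler.LambdaLR(
    opt, lambda x: (1 - x / total_iters) ** 0.9)
for _ in range(total_iters // 2):
    opt.step()
    sched.step()
print(opt.param_groups[0]['lr'])  # ~0.00107 == 0.002 * 0.5**0.9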
Example No. 5
import os
from random import choice

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils import data
from torchvision import models, transforms
from transforms_sample import *  # local module providing Resize, ToTensor, ...
from crfseg import CRF

val_dir = '/home/steveyang/projects/camelyon17/tile_images/deeplab/cross_val_neg/val/'
images_dir = os.path.join(val_dir, 'images/')
masks_dir = os.path.join(val_dir, 'masks/')
mean = (182.5253448486328, 182.49656677246094, 182.4678192138672)
std = (36.937557220458984, 36.780677795410156, 36.90703582763672)

checkpoint_path = '/home/steveyang/projects/camelyon17/deeplab/checkpoints/deeplab_resnet101_experiment_10_100.pth'
model = models.segmentation.deeplabv3_resnet50(num_classes=2)


# as in the training script, unwrap torchvision's DeepLabV3 dict output
# before the CRF layer
class DeepLabLogits(torch.nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        return self.net(x)['out']


model = torch.nn.Sequential(DeepLabLogits(model), CRF(n_spatial_dims=2))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model'])

images_files = os.listdir(images_dir)


def inference_random(images_files):
    filename = choice(images_files)
    input_image = Image.open(os.path.join(images_dir, filename))
    input_mask = Image.open(os.path.join(masks_dir, filename))

    preprocess = transforms.Compose([
        Resize(256),
        ToTensor(),
        Normalize(mean, std),  # assumed: the original listing is truncated here
    ])
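
The listing breaks off inside the preprocessing pipeline; the rest of the function presumably follows the usual inference pattern: eval mode, no_grad, forward, per-pixel argmax. A generic sketch with a random stand-in tensor instead of a real preprocessed tile:

model.to(device).eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256, device=device)  # stand-in for a preprocessed tile
    logits = model(x)                               # (1, 2, 256, 256) after the CRF
    pred = logits.argmax(dim=1)                     # (1, 256, 256) predicted class map
plt.imshow(pred[0].cpu().numpy())
plt.show()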
Example No. 6
    def __init__(self,
                 input_height: int,
                 enc_type: str = 'mnist_conv',
                 num_classes: int = 10,
                 first_conv: bool = False,
                 maxpool1: bool = False,
                 enc_out_dim: int = 64,
                 kl_coeff: float = 0.1,
                 latent_dim: int = 64,
                 lr: float = 1e-4,
                 k: int = 100,
                 input_channels: int = 1,
                 py_mode: int = 0,
                 recon_loss_type: str = 'l2',
                 mlp_hidden_dim: int = 500,
                 vae_type: str = 'gfz',
                 no_decoder: bool = False,
                 per_class: bool = False,
                 binary: bool = False,
                 **kwargs):
        """
        Args:
            input_height: height of the images
            enc_type: option between resnet18 or resnet50
            first_conv: use standard kernel_size 7, stride 2 at start or
                replace it with kernel_size 3, stride 1 conv
            maxpool1: use standard maxpool to reduce spatial dim of feat by a factor of 2
            enc_out_dim: set according to the out_channel count of
                encoder used (512 for resnet18, 2048 for resnet50)
            kl_coeff: coefficient for kl term of the loss
            latent_dim: dim of latent space
            lr: learning rate for Adam
            k: number of samples on latent variables during prediction time
        """

        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height
        self.num_classes = num_classes
        self.k = k
        self.input_channels = input_channels
        self.py_mode = py_mode
        self.recon_loss_type = recon_loss_type
        self.mlp_hidden_dim = mlp_hidden_dim
        self.vae_type = vae_type
        self.no_decoder = no_decoder
        self.per_class = per_class
        self.binary = binary

        self.example_input_array = torch.rand((1, 1, 28, 28))

        valid_encoders = {
            'resnet18': {
                'enc': resnet18_encoder,
                'dec': resnet18_decoder,
            },
            'resnet50': {
                'enc': resnet50_encoder,
                'dec': resnet50_decoder,
            },
            'mnist_conv': {
                'enc': mnist_encoder,
                'dec': mnist_decoder,
            }
        }
        self.valid_encoders = valid_encoders
        self.enc_type = enc_type
        print("vae type:", vae_type)
        if vae_type == 'gfz':
            self.feat_recon_input_dim = self.latent_dim + self.num_classes
        elif vae_type == 'gbz':
            self.feat_recon_input_dim = self.latent_dim
        else:
            raise ValueError(f"Unknown vae_type: {vae_type}")

        mlp_y_layers = 2

        if enc_type not in valid_encoders:
            raise ValueError(f"Invalid encoder: {enc_type}")
        else:
            # self.encoder = valid_encoders[enc_type]['enc']()
            if 'mnist' in enc_type:
                mlp_y_layers = 1
                self.encoder = valid_encoders[enc_type]['enc'](
                    kernel_sizes=[5, 5, 5],
                    strides=[1, 1, 1],
                    n_channels=[64, 64, 64],
                    maxpool=True)
                if per_class:
                    self.decoder = nn.ModuleList([
                        valid_encoders[enc_type]['dec'](
                            self.enc_out_dim,
                            recon_loss_type=self.recon_loss_type)
                        for _ in range(self.num_classes)
                    ])
                else:
                    self.decoder = valid_encoders[enc_type]['dec'](
                        self.enc_out_dim, recon_loss_type=self.recon_loss_type)
            else:
                self.encoder = valid_encoders[enc_type]['enc'](first_conv,
                                                               maxpool1)
                if per_class:
                    self.decoder = nn.ModuleList([
                        valid_encoders[enc_type]['dec'](self.enc_out_dim,
                                                        self.input_height,
                                                        first_conv, maxpool1)
                        for _ in range(self.num_classes)
                    ])
                else:
                    self.decoder = valid_encoders[enc_type]['dec'](
                        self.enc_out_dim, self.input_height, first_conv,
                        maxpool1)
        if self.no_decoder:
            # use the encoder as a frozen feature extractor
            self.decoder = nn.Identity()
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.mlp_mu_z = MLP(self.enc_out_dim + self.num_classes,
                            self.latent_dim,
                            hidden_dim=self.mlp_hidden_dim)
        self.mlp_var_z = MLP(self.enc_out_dim + self.num_classes,
                             self.latent_dim,
                             hidden_dim=self.mlp_hidden_dim)
        self.mlp_y = MLP(self.latent_dim,
                         self.num_classes,
                         num_layers=mlp_y_layers,
                         hidden_dim=self.mlp_hidden_dim)
        self.mlp_feat_recon = MLP(self.feat_recon_input_dim,
                                  self.enc_out_dim,
                                  hidden_dim=self.mlp_hidden_dim)

        if self.recon_loss_type == 'crf':
            # self.crf = CRF(n_spatial_dims=2, returns='proba')
            # self.crf = CRF(n_spatial_dims=2, n_iter=1)
            self.crf = CRF(n_spatial_dims=2)
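
When recon_loss_type == 'crf', the model keeps a trainable 2D CRF layer (the commented-out lines record other configurations that were tried). This excerpt does not show how the layer is applied, but mechanically it is just a smoothing module over image-shaped logits; a sketch under that assumption (the shapes and the two-class setup are made up):

import torch
from crfseg import CRF

crf = CRF(n_spatial_dims=2)
recon_logits = torch.randn(4, 2, 28, 28)  # (batch, n_classes, H, W), MNIST-sized
smoothed = crf(recon_logits)              # differentiable mean-field smoothing
assert smoothed.shape == recon_logits.shape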