Exemplo n.º 1
0
def main(config, checkpoint_path):
    """Train a SegNet whose parameters are initialised from pretrained VGG16.

    Args:
        config: configuration object handed unchanged to both
            ``CocoBboxDatagen`` and ``Trainer``.
        checkpoint_path: forwarded as-is to ``Trainer.run``.
    """
    data_source = CocoBboxDatagen(config)

    # ImageNet-pretrained VGG16, used only as a source of initial weights.
    backbone = models.vgg16(pretrained=True)

    # NOTE(review): the meaning of the positional (2, 3) depends on this
    # project's SegNet signature (presumably classes / input channels) - confirm.
    network = SegNet(2, 3)
    network.init_vgg16_params(backbone)

    Trainer(data_source, network, config).run(checkpoint_path)
    def __init__(self, in_channels=12, use_model='unet', use_d8=False,
                 learning_rate=0.02, adam_epsilon=1e-8, **kwargs):
        """Build the segmentation backbone selected by ``use_model``.

        All constructor arguments (including ``**kwargs``) are recorded with
        ``save_hyperparameters()``. Backbone-specific settings (``bilinear``,
        ``backbone``, ``drop_rate``, ``use_kriging_loss``, ...) are read back
        from ``self.hparams``, so they are expected to arrive via ``**kwargs``.

        Args:
            in_channels: number of input channels given to the backbone.
            use_model: backbone name, matched case-insensitively. One of
                'unet', 'lhn_unet', 'deep_lab', 'modsegnet', 'segnet',
                'aspp_segnet', 'sp_segnet', 'dl_segnet'.
            use_d8: when true, also create ``self.d8_emb``, an
                ``nn.Embedding(9, 3, max_norm=1)``.
            learning_rate: not read in this method (kept in hparams).
            adam_epsilon: not read in this method (kept in hparams).

        Raises:
            NotImplementedError: ``use_model`` names an unknown backbone. It
                subclasses ``Exception``, so existing handlers keep working.
        """
        super(DrainageNetworkExtractor, self).__init__()
        self.save_hyperparameters()

        hp = self.hparams
        name = use_model.lower()  # normalise once instead of in every branch
        n_out = 12  # every backbone below is built with 12 output classes

        if name == 'unet':
            self.model = UNet(n_channels=in_channels, n_classes=n_out, bilinear=hp.bilinear)
        elif name == 'lhn_unet':
            self.model = LHNUNet(n_channels=in_channels, n_classes=n_out,
                                 n_classes_l1=hp.n_classes_l1, n_classes_l2=hp.n_classes_l2,
                                 n_classes_l3=hp.n_classes_l3, n_classes_l4=hp.n_classes_l4)
        elif name == 'deep_lab':
            self.model = DeepLab(backbone=hp.backbone, in_channels=in_channels, num_classes=n_out,
                                 sync_bn=hp.sync_bn, freeze_bn=hp.freeze_bn,
                                 output_stride=hp.output_stride)
        elif name == 'modsegnet':
            self.model = ModSegNet(num_classes=n_out, n_init_features=in_channels,
                                   drop_rate=hp.drop_rate)
        elif name == 'segnet':
            self.model = SegNet(num_classes=n_out, n_init_features=in_channels,
                                drop_rate=hp.drop_rate,
                                use_kriging_loss=hp.use_kriging_loss)
        elif name == 'aspp_segnet':
            self.model = ASPPSegNet(num_classes=n_out, n_init_features=in_channels,
                                    use_kriging_loss=hp.use_kriging_loss)
        elif name == 'sp_segnet':
            self.model = SPSegNet(num_classes=n_out, n_init_features=in_channels)
        elif name == 'dl_segnet':
            self.model = DLSegNet(num_classes=n_out, n_init_features=in_channels,
                                  drop_rate=hp.drop_rate)
        else:
            raise NotImplementedError(f"{use_model} is not implemented")

        if use_d8:
            # NOTE(review): 9 entries presumably cover the 8 D8 flow directions
            # plus one extra code (e.g. no-flow) - confirm against the data.
            self.d8_emb = nn.Embedding(9, 3, max_norm=1)
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this module's options.

        A preliminary ``parse_known_args()`` pass (over the process command
        line) reads ``--use-model``, so that the chosen backbone's own options
        can be appended as well. An unrecognised backbone name adds nothing
        extra here; the constructor reports it.

        Declared ``@staticmethod`` because it takes no ``self``. This keeps it
        callable on both the class and an instance.

        Args:
            parent_parser: ``argparse.ArgumentParser`` whose arguments are inherited.

        Returns:
            The extended parser. When a backbone matched, this is whatever that
            backbone's ``add_model_specific_args`` returned.
        """
        # "@" lets arguments be read from a file given as "@path".
        parser = ArgumentParser(parents=[parent_parser], add_help=False,
                                fromfile_prefix_chars="@")
        parser.add_argument('--use-coord', action='store_true')
        parser.add_argument('--use-d8', action='store_true')
        parser.add_argument('--use-slope', action='store_true')
        parser.add_argument('--use-curvature', action='store_true')
        parser.add_argument('--in-channels', type=int, default=12)
        parser.add_argument('--use-model', type=str, default='unet')
        parser.add_argument('--learning-rate', type=float, default=0.02)
        parser.add_argument('--adam-epsilon', type=float, default=1e-8)
        parser.add_argument('--use-kriging-loss', action='store_true')

        # Peek at --use-model only; the remaining arguments are ignored here.
        temp_args, _ = parser.parse_known_args()
        name = temp_args.use_model.lower()

        if name == 'unet':
            parser = UNet.add_model_specific_args(parser)
        elif name == 'lhn_unet':
            parser = LHNUNet.add_model_specific_args(parser)
        elif name == 'deep_lab':
            parser = DeepLab.add_model_specific_args(parser)
        elif name == 'modsegnet':
            parser = ModSegNet.add_model_specific_args(parser)
        elif name == 'segnet':
            parser = SegNet.add_model_specific_args(parser)
        elif name == 'aspp_segnet':
            parser = ASPPSegNet.add_model_specific_args(parser)
        elif name == 'sp_segnet':
            parser = SPSegNet.add_model_specific_args(parser)
        elif name == 'dl_segnet':
            parser = DLSegNet.add_model_specific_args(parser)

        return parser
Exemplo n.º 4
0
    def __init__(self, seg_path):
        """Create the binary SegNet, load the auggi-trained weights, ready it for inference.

        Args:
            seg_path: location of the saved weights, passed to
                ``SegNet.load_from_filename``.
        """
        num_labels = 1  # a single output label: binary classifier

        # NOTE(review): input_nbr=3 presumably means 3-channel (RGB) input - confirm.
        self.model = network = SegNet(input_nbr=3, label_nbr=num_labels)
        network.load_from_filename(seg_path)
        network.eval()

        # First GPU if one is present, otherwise the CPU.
        use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if use_gpu else 'cpu')
        network.to(self.device)
Exemplo n.º 5
0
def get_model(model_name, input_channels, class_num):
    """Instantiate a segmentation network by name.

    The model modules are imported lazily, so only the selected architecture
    has to be importable.

    Args:
        model_name: either 'unet' or 'segnet'.
        input_channels: passed as the first positional constructor argument.
        class_num: passed as the second positional constructor argument.

    Returns:
        The newly constructed network.

    Raises:
        ValueError: ``model_name`` is not one of the supported network types.
    """
    if model_name == 'unet':
        from models.unet import UNet
        return UNet(input_channels, class_num)

    if model_name == 'segnet':
        from models.segnet import SegNet
        return SegNet(input_channels, class_num)

    raise ValueError(
        f"network type {model_name!r} is not supported "
        "(expected 'unet' or 'segnet')")
Exemplo n.º 6
0
class Segmentor:
    """Wrap a binary SegNet with auggi-trained weights for single-image inference.

    ``segment`` takes an RGB PIL image and returns the model output for it
    as a PIL image.
    """

    def __init__(self, seg_path):
        """Load the SegNet weights from ``seg_path`` and prepare it for inference.

        Args:
            seg_path: path to the trained weights, passed to
                ``SegNet.load_from_filename``.
        """
        n_labels = 1  # binary classifier: a single output label

        self.model = SegNet(input_nbr=3, label_nbr=n_labels)
        self.model.load_from_filename(seg_path)  # load auggi trained weights
        self.model.eval()  # set to eval mode

        # First GPU when available, otherwise the CPU.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)

    def segment(self, image):
        """Run the model on one RGB PIL image and return its output as a PIL image.

        Preprocessing steps:

        1. Resize to 224x224.
        2. Convert to a tensor.
        3. Normalise with the ImageNet channel mean/std.
        4. Add a batch dimension of 1 and move to ``self.device``.

        Args:
            image: RGB ``PIL.Image``.

        Returns:
            ``PIL.Image`` built from the first (only) batch entry of the model
            output. NOTE(review): whether that output is already a thresholded
            binary mask depends on SegNet's forward - confirm.
        """
        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        batch = preprocess(image).unsqueeze(0).to(self.device)  # batch of 1

        # Pure inference: skip autograd graph construction to save memory/time.
        with torch.no_grad():
            mask = self.model(batch)

        # mask is batched; take the single entry before converting.
        return F.to_pil_image(mask[0])
Exemplo n.º 7
0
Arquivo: model.py Projeto: kqf/hubmap
def build_model(max_epochs=2, logdir=".tmp/", train_split=None):
    """Assemble the skorch-style ``SegNet`` net around ``ResUNet`` with its callbacks.

    Args:
        max_epochs: number of training epochs.
        logdir: directory shared by the TensorBoard writer and both
            checkpoint callbacks.
        train_split: forwarded unchanged as the net's ``train_split``.

    Returns:
        The configured (not yet fitted) ``SegNet`` instance.
    """
    # Cyclic learning rate, stepped after every batch.
    cyclic_lr = skorch.callbacks.LRScheduler(
        policy=torch.optim.lr_scheduler.CyclicLR,
        base_lr=0.00001,
        max_lr=0.4,
        step_size_up=1900,
        step_size_down=3900,
        step_every='batch',
    )

    callbacks = [
        skorch.callbacks.ProgressBar(),
        # Per-epoch IoU via `score`; larger values are better.
        skorch.callbacks.EpochScoring(score, name='iou', lower_is_better=False),
        TensorBoardWithImages(SummaryWriter(logdir)),
        skorch.callbacks.Checkpoint(dirname=logdir),
        skorch.callbacks.TrainEndCheckpoint(dirname=logdir),
        cyclic_lr,
        # Apply `init` to parameters matching the "*" pattern (i.e. all of them).
        skorch.callbacks.Initializer("*", init),
    ]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    return SegNet(
        ResUNet,
        module__pretrained=False,
        criterion=BCEWithLogitsLossPadding,
        criterion__padding=0,
        batch_size=32,
        max_epochs=max_epochs,
        iterator_train__shuffle=True,
        iterator_train__num_workers=4,
        iterator_valid__shuffle=False,
        iterator_valid__num_workers=4,
        train_split=train_split,
        callbacks=callbacks,
        device=device,
    )
    #     list_file=train_path, img_dir=img_dir, mask_dir=mask_dir
    # )

    # NOTE(review): the training loader does not shuffle - confirm this is intended.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        # pin_memory=True,
    )

    # Pick the architecture; any value other than "unet"/"segnet" falls back to PSPNet.
    if args.model == "unet":
        model = UNet(input_channels=NUM_INPUT_CHANNELS,
                     output_channels=NUM_OUTPUT_CHANNELS)
    elif args.model == "segnet":
        # Same keyword names as the UNet call above (was misspelled `output_hannels`).
        model = SegNet(input_channels=NUM_INPUT_CHANNELS,
                       output_channels=NUM_OUTPUT_CHANNELS)
    else:
        model = PSPNet(
            layers=50,
            bins=(1, 2, 3, 6),
            dropout=0.1,
            classes=NUM_OUTPUT_CHANNELS,
            use_ppm=True,
            pretrained=True,
        )

    # class_weights = 1.0 / train_dataset.get_class_probability()
    # criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    criterion = torch.nn.CrossEntropyLoss()

    if CUDA:
Exemplo n.º 9
0
from common import args

from models.segnet import SegNet
from models.fcn.fcn import Model as FCNet

# Dispatch on the architecture named in `args.model`.
if args.model == 'segnet':
    model = SegNet()
    # Evaluate when `args.test` is truthy, otherwise train.
    run = model.test if args.test else model.train
    run()

elif args.model == 'fcn':
    # NOTE(review): `session`, `network` and `config` are not defined in the
    # visible code, so this branch raises NameError as written - confirm where
    # they are supposed to come from.
    model = FCNet(session, network, config["categories"])
Exemplo n.º 10
0
                              shuffle=True)
test_dataset = AuggiDetectionDataset(
    mode="TEST",
    size=IMG_SHAPE,
    use_coco_dataset=USE_COCO_DATASET,
    use_augmentation=False,
    augmentation_uniform_threshold=DATA_AUGMENTATION_UNIFORM_RANDOM_THRESHOLD)
# NOTE(review): the test loader shuffles, so batch order is not reproducible
# between runs - confirm this is intended.
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# CHECK CUDA AVAILABILITY (queried once, reused below)
CUDA_AVAILABLE = torch.cuda.is_available()

# Device configuration, derived from the same flag so the model and any
# tensors later sent to `device` end up on the same device.
device = torch.device('cuda:0' if CUDA_AVAILABLE else 'cpu')

##
## INSTANTIATE MODEL
##
model = SegNet(input_nbr=3, label_nbr=N_LABELS)
model.load_from_filename(MODEL_PATH)  # load segnet weights
model.eval()  # set to eval mode

if CUDA_AVAILABLE:
    # `.to(device)` rather than `.cuda()`: `.cuda()` targets the *current*
    # CUDA device, which need not be the cuda:0 used for `device`.
    model.to(device)
else:
    model.float()


def denormalize(img):
    # takes 3 dim tensor and denormalizes