Пример #1
0
    def __init__(self, model_path=None):
        """Build an encoder-only ENet and load pretrained encoder weights.

        Args:
            model_path: Path to a checkpoint holding the encoder ``state_dict``.
                Defaults to the original author's local checkpoint, kept only
                for backward compatibility; other machines should pass a path.
        """
        if model_path is None:
            model_path = '/Users/yuxuanliu/Desktop/enet.pytorch/ckpt/kitti_checkpoint_19-02-19_22-04-01_encoder_ENet_city_[320, 640]_lr_0.0005.pth'

        self.net = ENet(only_encode=True)
        # Load onto the CPU so the checkpoint opens on machines without CUDA.
        encoder_weight = torch.load(model_path, map_location='cpu')
        self.net.encoder.load_state_dict(encoder_weight)
        # Pretrained weights are loaded for inference; eval mode makes any
        # dropout / batch-norm layers behave deterministically.
        self.net.eval()

        # Input pipeline: PIL image -> tensor -> per-channel normalization.
        mean_std = cfg.DATA.MEAN_STD
        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ])
Пример #2
0
class LaneNet_ENet_1E1D(nn.Module):
    """LaneNet built on a single ENet backbone.

    ENet is the real-time segmentation network of Paszke et al. (2016).
    "1E1D" means one encoder and one decoder. The embedding branch and the
    binary-segmentation branch share the whole backbone and differ only in
    their final 1x1 convolution.

    Args:
        embedding_dim: Number of channels of the per-pixel embedding output
            (default 3).
        num_classes: Number of channels of the segmentation logits (default 2).
    """

    def __init__(self, embedding_dim=3, num_classes=2):
        super().__init__()

        # Both heads assume the backbone output has 64 channels.
        self.model = ENet(64)
        self.conv_logit = nn.Conv2d(64, num_classes, 1, 1)
        self.conv_embedding = nn.Conv2d(64, embedding_dim, 1, 1)

    def forward(self, input):
        """Return ``(embedding, logit)`` computed from the shared backbone features."""
        # Call the module rather than ``.forward`` so registered hooks run.
        x = self.model(input)

        logit = self.conv_logit(x)
        embedding = self.conv_embedding(x)

        return embedding, logit
def get_model(name):
    """Build the segmentation model registered under ``name``.

    Every model receives a square ``(IMG_SIZE, IMG_SIZE, 3)`` input shape and
    ``CLS_NUM`` output classes.

    Args:
        name: One of ``'hlnet'``, ``'fastscnn'``, ``'lednet'``, ``'dfanet'``,
            ``'enet'`` or ``'mobilenet'``.

    Returns:
        The constructed model.

    Raises:
        NameError: If ``name`` is not a known model name. NameError is kept
            for backward compatibility with existing callers.
    """
    # Builders are lambdas so that only the requested model is constructed.
    builders = {
        'hlnet': lambda shape: HLNet(input_shape=shape, cls_num=CLS_NUM),
        'fastscnn': lambda shape: Fast_SCNN(num_classes=CLS_NUM,
                                            input_shape=shape).model(),
        'lednet': lambda shape: LEDNet(groups=2,
                                       classes=CLS_NUM,
                                       input_shape=shape).model(),
        'dfanet': lambda shape: DFANet(input_shape=shape,
                                       cls_num=CLS_NUM,
                                       size_factor=2),
        'enet': lambda shape: ENet(input_shape=shape, cls_num=CLS_NUM),
        'mobilenet': lambda shape: MobileNet(input_shape=shape, cls_num=CLS_NUM),
    }
    try:
        build = builders[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which the original if/elif chain
        # also rejected with NameError.
        raise NameError(
            f"No corresponding model for {name!r}; expected one of: "
            f"{', '.join(builders)}"
        ) from None
    return build((IMG_SIZE, IMG_SIZE, 3))
Пример #4
0
def CreateModel(args):

    if args.model == 'enet':
        model = ENet(num_classes=args.num_classes).cuda()
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    elif args.model == 'deeplab':
        model = Deeplab(num_classes=args.num_classes).cuda()
        optimizer = None

    elif args.model == 'frrnet':
        model = FRRNet(out_channels=args.num_classes).cuda()
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate)

    """
    elif args.model == 'fcn8s':
        model = FCN8s(pretrained_net=None, n_class=args.num_classes).cuda()
        optimizer = optim.SGD(
            model.parameters(), lr=args.learning_rate)
    """
    
    else:
Пример #5
0
class SegmentationModel:
    """Semantic-segmentation wrapper around an encoder-only ENet.

    Input images are converted to tensors and normalized with
    ``cfg.DATA.MEAN_STD``. They are then cropped to at most
    ``HEIGHT`` x ``WIDTH`` before being fed to the network.
    """

    # Crop size (pixels) applied to input images; also the coordinate frame
    # expected by ``find_seg_class``.
    WIDTH = 1240
    HEIGHT = 376

    # Machine-specific checkpoint used when no ``model_path`` is given; kept
    # only for backward compatibility.
    DEFAULT_MODEL_PATH = '/Users/yuxuanliu/Desktop/enet.pytorch/ckpt/kitti_checkpoint_19-02-19_22-04-01_encoder_ENet_city_[320, 640]_lr_0.0005.pth'

    def __init__(self, model_path=None):
        """Build the network and load pretrained encoder weights.

        Args:
            model_path: Path to a checkpoint holding the encoder ``state_dict``.
                Defaults to ``DEFAULT_MODEL_PATH``.
        """
        if model_path is None:
            model_path = SegmentationModel.DEFAULT_MODEL_PATH

        self.net = ENet(only_encode=True)
        # Load onto the CPU so the checkpoint opens on machines without CUDA.
        encoder_weight = torch.load(model_path, map_location='cpu')
        self.net.encoder.load_state_dict(encoder_weight)
        # The network is only used for inference below; eval mode makes any
        # dropout / batch-norm layers behave deterministically.
        self.net.eval()

        mean_std = cfg.DATA.MEAN_STD
        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ])

    def prepare_image(self, img):
        """Convert an image array into a normalized ``(1, C, H, W)`` batch tensor.

        Args:
            img: Image array; it is cast to ``uint8`` before conversion.

        Returns:
            A tensor with a leading batch dimension of 1. It keeps at most 3
            channels and is cropped to at most ``HEIGHT`` x ``WIDTH``.
        """
        pil_img = Image.fromarray(img.astype(np.uint8))
        img_tensor = self.transform(pil_img)
        # Keep at most 3 channels (drops e.g. alpha) and crop to the model size.
        img_tensor = img_tensor[:3,
                                :SegmentationModel.HEIGHT,
                                :SegmentationModel.WIDTH]
        return img_tensor.unsqueeze(0)

    def segment_image(self, img):
        """Segment ``img``, show the input and the colorized mask, and return results.

        Args:
            img: Image array accepted by ``prepare_image``.

        Returns:
            A tuple ``(np_classes, seg_img)``. ``np_classes`` is the per-pixel
            class-index array. ``seg_img`` is the ``colorize_mask`` output
            divided by 255.
        """
        img_tensor = self.prepare_image(img)

        # Inference needs no gradients; skipping autograd saves memory and time.
        with torch.no_grad():
            seg = self.net(img_tensor)
        _, classes = torch.max(seg, 1)
        classes = classes.squeeze(0)

        # Debug output: number of pixels predicted as class 9.
        print(torch.sum(classes == 9))

        np_classes = classes.cpu().numpy()
        color_mask = colorize_mask(np_classes)

        seg_img = color_mask / 255

        # OpenCV displays BGR, so reverse the channel axis (presumably
        # colorize_mask yields RGB -- confirm).
        seg_img_bgr = seg_img[:, :, ::-1]
        cv2.imshow('img', img)
        cv2.imshow('seg', seg_img_bgr)
        cv2.waitKey(1)

        return np_classes, seg_img

    def find_seg_class(self, class_mask, position):
        """Return the class of ``class_mask`` at image coordinate ``position``.

        Args:
            class_mask: 2-D array of class indices.
            position: ``(x, y)`` in the ``WIDTH`` x ``HEIGHT`` image frame. It
                is rescaled to the mask's resolution.

        Returns:
            The mask entry at the rescaled position. Out-of-range coordinates
            are clamped to the mask border. Without clamping, ``x == WIDTH``
            would raise IndexError and large negative values would wrap around.
        """
        mask_height, mask_width = class_mask.shape
        sample_x = int(position[0] / SegmentationModel.WIDTH * mask_width)
        sample_y = int(position[1] / SegmentationModel.HEIGHT * mask_height)
        sample_x = min(max(sample_x, 0), mask_width - 1)
        sample_y = min(max(sample_y, 0), mask_height - 1)

        return class_mask[sample_y, sample_x]
Пример #6
0
    def __init__(self, embedding_dim=3, num_classes=2):
        """Build the shared ENet backbone and its two 1x1 output heads.

        Args:
            embedding_dim: Channels of the per-pixel embedding head (default 3).
            num_classes: Channels of the segmentation-logit head (default 2).
        """
        super().__init__()

        # Both heads assume the backbone output has 64 channels.
        self.model = ENet(64)
        self.conv_logit = nn.Conv2d(64, num_classes, 1, 1)
        self.conv_embedding = nn.Conv2d(64, embedding_dim, 1, 1)
Пример #7
0
                           shuffle=True,
                           num_workers=num_workers)
 val_loader = DataLoader(val_dataset,
                         batch_size=batch_size,
                         num_workers=num_workers)
 test_loader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)
 """Get Class Weighting"""
 class_weights = torch.from_numpy(
     weighing_class(train_loader, num_classes=12)).float().to(device)
 unlabeled_idx = list(color_encoding).index('unlabeled')
 class_weights[unlabeled_idx] = 0
 print("class_weights : ", class_weights)
 """Get model,loss,optimizer,lr_scheduler"""
 model = ENet(num_classes=12).to(device)
 loss_fn = nn.CrossEntropyLoss(weight=class_weights)
 optimizer = torch.optim.Adam(model.parameters(),
                              lr=lr,
                              weight_decay=weight_decay)
 lr_scheduler = lr_scheduler.StepLR(optimizer, lr_decay_epochs, lr_decay)
 """Training"""
 best_val_iou = 0
 for epoch in range(epochs):
     #Train
     train_loss = train_batch(train_loader,
                              model,
                              loss_fn,
                              optimizer,
                              device=device)
     lr_scheduler.step()
Пример #8
0
def prepare_network(checkpoint_path='model/ENet'):
    """Load a trained ENet and return it in eval mode on the GPU.

    The model is sized with ``len(cityscapes_labels)`` classes.

    Args:
        checkpoint_path: Checkpoint file containing a ``'state_dict'`` entry.
            The default is relative to the current working directory.

    Returns:
        The ENet model in eval mode, moved to CUDA.
    """
    model = ENet(len(cityscapes_labels))
    # The freshly built model lives on the CPU, so map the checkpoint there too.
    # This avoids a transient second copy of the weights on the GPU, and avoids
    # failure when the checkpoint was saved on a device id missing here.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model.eval().cuda()