Ejemplo n.º 1
0
    def __init__(self, model_path='/Users/yuxuanliu/Desktop/enet.pytorch/ckpt/kitti_checkpoint_19-02-19_22-04-01_encoder_ENet_city_[320, 640]_lr_0.0005.pth'):
        """Build an encoder-only ENet, load its weights and set up preprocessing.

        Args:
            model_path: path to a checkpoint holding the encoder's
                ``state_dict``. The default is the original machine-specific
                absolute path; pass another path to run anywhere else.
        """
        self.net = ENet(only_encode=True)
        # Deserialize onto the CPU so loading does not depend on the device
        # the checkpoint was saved from.
        encoder_weight = torch.load(model_path, map_location='cpu')
        self.net.encoder.load_state_dict(encoder_weight)

        # ToTensor followed by Normalize; cfg.DATA.MEAN_STD is unpacked into
        # Normalize, so presumably it is a (mean, std) pair -- confirm in cfg.
        mean_std = cfg.DATA.MEAN_STD
        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ])
Ejemplo n.º 2
0
class LaneNet_ENet_1E1D(nn.Module):
    '''
    A LaneNet model built on ENet (Paszke et al., "ENet: A Deep Neural Network
    Architecture for Real-Time Semantic Segmentation").
    1E1D means one encoder and one decoder: the embedding and the binary
    segmentation share the same encoder and decoder, except for the last layer,
    which is a separate 1x1 convolution per output head.
    '''
    def __init__(self):
        super().__init__()

        # Shared ENet trunk; both heads below consume 64 channels from it.
        self.model = ENet(64)
        # 1x1 conv heads: 2-channel segmentation logits and 3-channel embedding.
        self.conv_logit = nn.Conv2d(64, 2, 1, 1)
        self.conv_embedding = nn.Conv2d(64, 3, 1, 1)

    def forward(self, input):
        """Run the shared trunk and both heads.

        Returns:
            ``(embedding, logit)``: the 3-channel embedding map and the
            2-channel segmentation logits, in that order.
        """
        # Call the module rather than ``.forward`` so nn.Module hooks run.
        x = self.model(input)

        logit = self.conv_logit(x)
        embedding = self.conv_embedding(x)

        return embedding, logit
def get_model(name):
    """Build the segmentation model registered under ``name``.

    Every model is created for ``(IMG_SIZE, IMG_SIZE, 3)`` inputs and
    ``CLS_NUM`` classes.

    Args:
        name: one of 'hlnet', 'fastscnn', 'lednet', 'dfanet', 'enet',
            'mobilenet'.

    Returns:
        The constructed model. For 'fastscnn' and 'lednet' this is the object
        returned by the wrapper's ``.model()`` method.

    Raises:
        NameError: if ``name`` is not a known model. The type is kept for
            existing callers; the message names the bad value and the choices.
    """
    # Builders are only looked up here; none runs until the name is validated,
    # so an unknown name fails fast without constructing anything.
    builders = {
        'hlnet': lambda shape: HLNet(input_shape=shape, cls_num=CLS_NUM),
        'fastscnn': lambda shape: Fast_SCNN(num_classes=CLS_NUM,
                                            input_shape=shape).model(),
        'lednet': lambda shape: LEDNet(groups=2,
                                       classes=CLS_NUM,
                                       input_shape=shape).model(),
        'dfanet': lambda shape: DFANet(input_shape=shape,
                                       cls_num=CLS_NUM,
                                       size_factor=2),
        'enet': lambda shape: ENet(input_shape=shape, cls_num=CLS_NUM),
        'mobilenet': lambda shape: MobileNet(input_shape=shape, cls_num=CLS_NUM),
    }
    try:
        build = builders[name]
    except (KeyError, TypeError):  # TypeError: unhashable name, e.g. a list
        raise NameError(
            f"No corresponding model... got {name!r}; "
            f"expected one of {sorted(builders)}"
        ) from None

    return build((IMG_SIZE, IMG_SIZE, 3))
Ejemplo n.º 4
0
def CreateModel(args):
    """Create the model selected by ``args.model`` and its optimizer.

    Each visible branch builds the network with ``args.num_classes`` outputs
    and moves it to the GPU with ``.cuda()``.

    NOTE(review): this snippet is cut off after the final ``else:``, so the
    remaining branch and the return statement are not visible here.
    """

    if args.model == 'enet':
        model = ENet(num_classes=args.num_classes).cuda()
        # Adam with both learning rate and weight decay taken from args.
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    elif args.model == 'deeplab':
        model = Deeplab(num_classes=args.num_classes).cuda()
        # No optimizer is created for DeepLab in this function.
        optimizer = None

    elif args.model == 'frrnet':
        model = FRRNet(out_channels=args.num_classes).cuda()
        # Adam without weight decay for FRRNet.
        optimizer = optim.Adam(
            model.parameters(), lr=args.learning_rate)

    # NOTE(review): the triple-quoted string below is an expression statement,
    # not a comment. It ends the if/elif chain above, so the ``else:`` that
    # follows has no matching ``if`` and Python rejects this function with a
    # SyntaxError. Turn the disabled fcn8s branch into ``#`` comments (or delete
    # it) to fix this.
    """
    elif args.model == 'fcn8s':
        model = FCN8s(pretrained_net=None, n_class=args.num_classes).cuda()
        optimizer = optim.SGD(
            model.parameters(), lr=args.learning_rate)
    """
    
    else:
Ejemplo n.º 5
0
class SegmentationModel:
    """Wrap an encoder-only ENet for single-image semantic segmentation.

    Images are normalized, cropped to ``HEIGHT`` x ``WIDTH``, run through the
    encoder, and reduced to a per-pixel class-index mask with an argmax over
    dimension 1 of the network output.
    """

    WIDTH = 1240
    HEIGHT = 376

    # Checkpoint used when no explicit path is given. This absolute path is
    # machine-specific; pass ``model_path`` to load a different checkpoint.
    DEFAULT_MODEL_PATH = '/Users/yuxuanliu/Desktop/enet.pytorch/ckpt/kitti_checkpoint_19-02-19_22-04-01_encoder_ENet_city_[320, 640]_lr_0.0005.pth'

    def __init__(self, model_path=None):
        """Build the encoder-only network, load its weights and set up preprocessing.

        Args:
            model_path: path to a checkpoint holding the encoder's
                ``state_dict``. Defaults to ``DEFAULT_MODEL_PATH``.
        """
        if model_path is None:
            model_path = SegmentationModel.DEFAULT_MODEL_PATH

        self.net = ENet(only_encode=True)
        # Deserialize onto the CPU regardless of the device the checkpoint
        # was saved from.
        encoder_weight = torch.load(model_path, map_location='cpu')
        self.net.encoder.load_state_dict(encoder_weight)
        # This wrapper only runs inference, so switch the network out of its
        # default training mode (affects dropout / batch-norm layers, if any).
        self.net.eval()

        # Normalize receives cfg.DATA.MEAN_STD unpacked; presumably (mean, std).
        mean_std = cfg.DATA.MEAN_STD
        self.transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ])

    def prepare_image(self, img):
        """Convert an image array into a batched, normalized, cropped tensor.

        The array is cast to uint8, converted with ``self.transform``, cropped
        to at most the first 3 entries of dim 0 and ``HEIGHT`` x ``WIDTH`` of
        the remaining dims, then given a leading batch dimension. This assumes
        the transform yields channel-first (C x H x W) tensors, as
        torchvision's ToTensor does.
        """
        pil_img = Image.fromarray(img.astype(np.uint8))
        img_tensor = self.transform(pil_img)
        img_tensor = img_tensor[:3, :SegmentationModel.HEIGHT, :SegmentationModel.WIDTH]
        return img_tensor.unsqueeze(0)

    def segment_image(self, img, show=True):
        """Segment one image.

        Args:
            img: image array accepted by ``prepare_image``.
            show: when True (the default, matching the previous behaviour),
                print the number of pixels predicted as class 9 and display
                the input and the colorized mask with OpenCV.

        Returns:
            ``(np_classes, seg_img)``: the per-pixel class indices as a NumPy
            array, and ``colorize_mask(np_classes) / 255``.
        """
        img_tensor = self.prepare_image(img)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            seg = self.net(img_tensor)
        # Argmax over dim 1 (presumably the class-score channel) -> labels.
        _, classes = torch.max(seg, 1)
        classes = classes.squeeze(0)

        if show:
            print(torch.sum(classes == 9))

        np_classes = classes.cpu().numpy()
        seg_img = colorize_mask(np_classes) / 255

        if show:
            # Reverse the last axis for OpenCV, which displays in BGR order.
            cv2.imshow('img', img)
            cv2.imshow('seg', seg_img[:, :, ::-1])
            cv2.waitKey(1)

        return np_classes, seg_img

    def find_seg_class(self, class_mask, position):
        """Return the class at an image-space ``(x, y)`` position.

        The position, given in the ``WIDTH`` x ``HEIGHT`` image frame, is
        rescaled to the 2-D ``class_mask`` grid. Positions on or beyond the
        border are clamped to the nearest valid cell instead of raising
        ``IndexError`` (e.g. ``x == WIDTH``) or wrapping via negative indices.
        """
        mask_height, mask_width = class_mask.shape
        sample_x = int(position[0] / SegmentationModel.WIDTH * mask_width)
        sample_y = int(position[1] / SegmentationModel.HEIGHT * mask_height)
        sample_x = min(max(sample_x, 0), mask_width - 1)
        sample_y = min(max(sample_y, 0), mask_height - 1)

        return class_mask[sample_y, sample_x]
Ejemplo n.º 6
0
    def __init__(self):
        """Create the shared ENet trunk and the two 1x1 convolution heads."""
        super().__init__()

        width = 64  # channel count passed from the trunk into both heads
        self.model = ENet(width)
        # Separate 1x1 heads: a 2-channel output and a 3-channel output.
        self.conv_logit = nn.Conv2d(width, 2, kernel_size=1, stride=1)
        self.conv_embedding = nn.Conv2d(width, 3, kernel_size=1, stride=1)
Ejemplo n.º 7
0
                           shuffle=True,
                           num_workers=num_workers)
 # Validation / test loaders (no shuffle argument is passed for these).
 val_loader = DataLoader(val_dataset,
                         batch_size=batch_size,
                         num_workers=num_workers)
 test_loader = DataLoader(test_dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)
 """Get Class Weighting"""
 # Per-class loss weights computed by weighing_class over the training loader
 # for 12 classes (keep in sync with ENet(num_classes=12) below), as a float
 # tensor on `device`.
 class_weights = torch.from_numpy(
     weighing_class(train_loader, num_classes=12)).float().to(device)
 # Position of 'unlabeled' in color_encoding's iteration order; presumably an
 # ordered name -> color mapping whose order matches the class indices (verify).
 unlabeled_idx = list(color_encoding).index('unlabeled')
 # Zero weight: targets of the 'unlabeled' class contribute nothing to
 # CrossEntropyLoss.
 class_weights[unlabeled_idx] = 0
 print("class_weights : ", class_weights)
 """Get model,loss,optimizer,lr_scheduler"""
 model = ENet(num_classes=12).to(device)
 loss_fn = nn.CrossEntropyLoss(weight=class_weights)
 optimizer = torch.optim.Adam(model.parameters(),
                              lr=lr,
                              weight_decay=weight_decay)
 # StepLR(optimizer, step_size=lr_decay_epochs, gamma=lr_decay): the LR is
 # multiplied by lr_decay every lr_decay_epochs scheduler steps.
 # NOTE(review): this assignment rebinds `lr_scheduler`, which the right-hand
 # side uses as the torch.optim.lr_scheduler module. If this code runs inside a
 # function, the assignment makes `lr_scheduler` local for the whole function
 # and this line raises UnboundLocalError; at module level it works but shadows
 # the module. Consider renaming the scheduler (e.g. `scheduler`) together with
 # every later use.
 lr_scheduler = lr_scheduler.StepLR(optimizer, lr_decay_epochs, lr_decay)
 """Training"""
 # Best validation IoU seen so far; presumably updated later in the loop.
 best_val_iou = 0
 for epoch in range(epochs):
     #Train
     train_loss = train_batch(train_loader,
                              model,
                              loss_fn,
                              optimizer,
                              device=device)
     # Advance the LR schedule once per epoch, after the training pass.
     lr_scheduler.step()
Ejemplo n.º 8
0
def prepare_network(checkpoint_path='model/ENet'):
    """Load a pretrained ENet and return it in eval mode on the GPU.

    Args:
        checkpoint_path: file written with ``torch.save`` containing a dict
            with a ``'state_dict'`` entry for a model built as
            ``ENet(len(cityscapes_labels))``. Defaults to the original
            relative path ``'model/ENet'``, which is resolved against the
            current working directory.

    Returns:
        The ENet model switched to eval mode and moved with ``.cuda()``.
    """
    ENet_model = ENet(len(cityscapes_labels))
    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's own parameters anyway, and this avoids failures when the
    # checkpoint was saved from a CUDA device that is not present here.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    ENet_model.load_state_dict(checkpoint['state_dict'])
    return ENet_model.eval().cuda()