Example #1
class ENetDemo(ImageInferEngine):
    def __init__(self, f, model_path):
        super(ENetDemo, self).__init__(f=f)

        self.target_size = (512, 1024)
        self.model_path = model_path
        self.num_classes = 20

        self.image_transform = transforms.Compose(
            [transforms.Resize(self.target_size),
             transforms.ToTensor()])

        self._init_model()

    def _init_model(self):
        self.model = ENet(self.num_classes).to(device)
        checkpoint = torch.load(self.model_path)
        self.model.load_state_dict(checkpoint['state_dict'])
        print('Model loaded!')

    def solve_a_image(self, img):
        # Preprocess the frame, add a batch dimension, and run the network
        images = Variable(
            self.image_transform(Image.fromarray(img)).to(device).unsqueeze(0))
        predictions = self.model(images)
        # Per-pixel argmax over the class dimension; shift the labels down by one
        _, predictions = torch.max(predictions.data, 1)
        prediction = predictions.cpu().numpy()[0] - 1
        return prediction

    def vis_result(self, img, net_out):
        # Colorize the predicted class map with the Cityscapes palette
        mask_color = np.asarray(label_to_color_image(net_out, 'cityscapes'),
                                dtype=np.uint8)
        # Resize the original frame to the network's input size and blend
        frame = cv2.resize(img, (self.target_size[1], self.target_size[0]))
        # mask_color = cv2.resize(mask_color, (frame.shape[1], frame.shape[0]))
        res = cv2.addWeighted(frame, 0.5, mask_color, 0.7, 1)
        return res
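A minimal driver for the class above, as a sketch only: the ImageInferEngine base class is not shown here, so the video-path argument and the run() entry point are assumptions about its interface rather than part of the original example.

if __name__ == '__main__':
    # Hypothetical usage: wrap a video file and the trained checkpoint
    video_path = 'demo.mp4'                           # placeholder input
    checkpoint_path = './save/ENet_Cityscapes/ENet'   # placeholder checkpoint
    demo = ENetDemo(f=video_path, model_path=checkpoint_path)
    demo.run()  # assumed to loop over frames, calling solve_a_image and vis_result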
Example #2
import torch
from models.enet import ENet
import os
from configs import config_factory

if __name__ == '__main__':
    # Path to the trained weights
    save_pth = os.path.join(config_factory['resnet_cityscapes'].respth,
                            'model_final.pth')

    # Build ENet for the 19 Cityscapes classes and load the weights
    model = ENet(nb_classes=19)
    model.load_state_dict(torch.load(save_pth))
    model.eval()

    # Trace the model with a dummy input batch and save it as TorchScript
    example = torch.rand(2, 3, 1024, 1024).cpu()
    traced_script_module = torch.jit.trace(model, example)
    traced_script_module.save(
        os.path.join(config_factory['resnet_cityscapes'].respth,
                     "model_dfanet_1024.pt"))
Example #3
from collections import OrderedDict

# Cityscapes class name -> RGB colour encoding
color_encoding = OrderedDict([
    ('unlabeled', (0, 0, 0)), ('road', (128, 64, 128)),
    ('sidewalk', (244, 35, 232)), ('building', (70, 70, 70)),
    ('wall', (102, 102, 156)), ('fence', (190, 153, 153)),
    ('pole', (153, 153, 153)), ('traffic_light', (250, 170, 30)),
    ('traffic_sign', (220, 220, 0)), ('vegetation', (107, 142, 35)),
    ('terrain', (152, 251, 152)), ('sky', (70, 130, 180)),
    ('person', (220, 20, 60)), ('rider', (255, 0, 0)), ('car', (0, 0, 142)),
    ('truck', (0, 0, 70)), ('bus', (0, 60, 100)), ('train', (0, 80, 100)),
    ('motorcycle', (0, 0, 230)), ('bicycle', (119, 11, 32))
])
num_classes = len(color_encoding)
model = ENet(num_classes).to(device)

# Load the pre-trained weights
model_path = "./save/ENet_Cityscapes/ENet"
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['state_dict'])
print('Model loaded successfully!')

# Run the inference
# Unless args.test is set, showcase how the model works on a sample image
if not args.test:
    model.eval()
    sample_image = torch.unsqueeze(sample_image, 0)
    with torch.no_grad():
        output = model(sample_image)
    print("Model output dimension:", output.shape)

    # Take the argmax over the class dimension to get a single class index per pixel
    _, predictions = torch.max(output.data, 1)

    label_to_rgb = transforms.Compose([
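The example ends mid-statement, but the remaining step is mapping each predicted class index back to its colour. A minimal NumPy sketch of that mapping, using the color_encoding defined at the top of the example rather than the truncated transforms pipeline:

import numpy as np

# Build a (num_classes, 3) palette from the OrderedDict, then index it with
# the (H, W) prediction map to get an (H, W, 3) RGB image
palette = np.array(list(color_encoding.values()), dtype=np.uint8)
pred = predictions[0].cpu().numpy()   # (H, W) per-pixel class indices
color_mask = palette[pred]            # (H, W, 3) colour image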
Example #4
File: main.py Project: shguan10/cat_coop
def train(train_loader,
          val_loader,
          class_weights,
          class_encoding,
          pretrained="./save/ENet.pt"):
    print("\nTraining...\n")

    num_classes = len(class_encoding)

    model = ENet(num_classes)

    if pretrained:
        model.load_state_dict(torch.load(pretrained)["state_dict"])

    # Initialize ENet
    model = model.to(device)
    # Check if the network architecture is correct
    print(model)

    # We use CrossEntropyLoss as it is the most frequently used loss for
    # classification problems with multiple classes, which fits this task.
    # The criterion combines LogSoftmax and NLLLoss (see the shape sketch
    # after this example).
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # ENet authors used Adam as the optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_miou = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0

    # Start Training
    train = Train(model, train_loader, optimizer, criterion, metric, device)
    val = Test(model, val_loader, criterion, metric, device)
    for epoch in range(start_epoch, args.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        lr_updater.step()
        epoch_loss, (iou, miou) = train.run_epoch(args.print_step)

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        if epoch % 10 == 0 or epoch + 1 == args.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val.run_epoch(args.print_step)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == args.epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if miou > best_miou:
                print("\nBest model thus far. Saving...\n")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      args)

    return model
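The CrossEntropyLoss used above operates directly on per-pixel logits: it expects (N, C, H, W) scores and (N, H, W) integer targets, and internally applies LogSoftmax over the class dimension followed by NLLLoss. A self-contained sketch of that shape contract with random tensors (no ENet or data loaders required):

import torch
import torch.nn as nn

num_classes = 20
criterion = nn.CrossEntropyLoss()                      # optionally weight=class_weights
logits = torch.randn(2, num_classes, 64, 64)           # (N, C, H, W) model output
target = torch.randint(0, num_classes, (2, 64, 64))    # (N, H, W) class indices
loss = criterion(logits, target)
print(loss.item())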