Beispiel #1
0
def single():
    """Segment one example image with a trained ERFNet and show the result.

    Reads ``test_content/example_01.png``, resizes it to
    ``(args.width, args.height)``, restores the model weights from the
    checkpoint at ``args.save_dir``/``args.name``, runs a single forward pass
    on the GPU, prints two throughput figures (model only, and model plus
    colour transform) and displays the input next to the colour-coded
    prediction.
    """
    print('Mode: Single')
    img = Image.open('test_content/example_01.png').convert('RGB')

    # Class name -> RGB colour used to render the predicted label map.
    class_encoding = OrderedDict([
        ('unlabeled', (0, 0, 0)), ('road', (128, 64, 128)),
        ('sidewalk', (244, 35, 232)), ('building', (70, 70, 70)),
        ('wall', (102, 102, 156)), ('fence', (190, 153, 153)),
        ('pole', (153, 153, 153)), ('traffic_light', (250, 170, 30)),
        ('traffic_sign', (220, 220, 0)), ('vegetation', (107, 142, 35)),
        ('terrain', (152, 251, 152)), ('sky', (70, 130, 180)),
        ('person', (220, 20, 60)), ('rider', (255, 0, 0)),
        ('car', (0, 0, 142)), ('truck', (0, 0, 70)), ('bus', (0, 60, 100)),
        ('train', (0, 80, 100)), ('motorcycle', (0, 0, 230)),
        ('bicycle', (119, 11, 32))
    ])

    num_classes = len(class_encoding)
    model = ERFNet(num_classes)
    model_path = os.path.join(args.save_dir, args.name)
    print('Loading model at:', model_path)
    checkpoint = torch.load(model_path)
    model = model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
    # Inference only: switch dropout / batch-norm layers to evaluation
    # behaviour, otherwise the prediction depends on training-mode statistics.
    model.eval()

    img = img.resize((args.width, args.height), Image.BILINEAR)
    start = time.time()
    # Convert to a tensor and prepend the batch dimension the model expects.
    images = transforms.ToTensor()(img).unsqueeze(0)
    with torch.no_grad():
        images = images.cuda()
        predictions = model(images)
        # CUDA kernels are launched asynchronously; wait for them so the
        # measured time covers the actual forward pass.
        torch.cuda.synchronize()
        end = time.time()
        print('model speed:', int(1 / (end - start)), "FPS")

        # Index of the highest-scoring class per pixel.
        _, predictions = torch.max(predictions.data, 1)
        label_to_rgb = transforms.Compose(
            [utils.LongTensorToRGBPIL(class_encoding),
             transforms.ToTensor()])
        color_predictions = utils.batch_transform(predictions.cpu(),
                                                  label_to_rgb)
        end = time.time()
        print('model+transform:', int(1 / (end - start)), "FPS")
        utils.imshow_batch(images.data.cpu(), color_predictions)
Beispiel #2
0
def lane_detect(im_tensor):
    """Detect lanes in one image and return a per-pixel lane class map.

    The image is segmented with an ERFNet; descriptors extracted from the
    segmentation are then classified with an LCNet. Both networks are created
    and their pretrained weights loaded on every call.

    Args:
        im_tensor: Image tensor with three dimensions; the last two are used
            as height and width. A batch dimension is added internally.

    Returns:
        An int64 tensor of shape ``(HEIGHT, WIDTH)``. Pixels belonging to a
        detected lane hold the predicted class index plus one, all other
        pixels are 0. The tensor lives on the GPU when CUDA is available.
    """
    # Image size
    _, HEIGHT, WIDTH = im_tensor.shape
    im_tensor = im_tensor.unsqueeze(0)

    # Creating CNNs and loading pretrained models
    segmentation_network = ERFNet(NUM_CLASSES_SEGMENTATION)
    classification_network = LCNet(NUM_CLASSES_CLASSIFICATION, DESCRIPTOR_SIZE,
                                   DESCRIPTOR_SIZE)

    segmentation_network.load_state_dict(
        torch.load(path + 'pretrained/erfnet_tusimple.pth',
                   map_location=map_location))
    model_path = path + 'pretrained/classification_{}_{}class.pth'.format(
        DESCRIPTOR_SIZE, NUM_CLASSES_CLASSIFICATION)
    classification_network.load_state_dict(
        torch.load(model_path, map_location=map_location))

    segmentation_network = segmentation_network.eval()
    classification_network = classification_network.eval()

    if torch.cuda.is_available():
        segmentation_network = segmentation_network.cuda()
        classification_network = classification_network.cuda()
        im_tensor = im_tensor.cuda()

    # Pure inference: no autograd graph is needed, so avoid building one.
    with torch.no_grad():
        out_segmentation = segmentation_network(im_tensor)
        # Keep only the index of the highest-scoring class per pixel.
        out_segmentation = out_segmentation.max(dim=1)[1]

        out_segmentation_np = out_segmentation.cpu().numpy()[0]
        descriptors, index_map = extract_descriptors(out_segmentation,
                                                     im_tensor)
        classes = classification_network(descriptors).max(1)[1]

        # Allocate the result on the same device as the (possibly moved) input.
        lane_map = torch.zeros(HEIGHT, WIDTH, dtype=torch.int64,
                               device=im_tensor.device)
        for i, lane_index in index_map.items():
            # Class 0 is reserved for "no lane", hence the offset of one.
            lane_map[out_segmentation_np == lane_index] = classes[i] + 1

    return lane_map
Beispiel #3
0
    def __init__(self, options, model=None):
        """Prepare labels, data loader, output paths, model and metrics.

        Args:
            options: Option namespace. It is stored as ``self.opt`` and mutated
                in place: ``task_to_val`` is replaced by ``train_set`` when it
                is 0, and ``validate`` / ``model_name`` are overwritten when
                no model is passed.
            model: Network to validate during training. When ``None``, a new
                ERFNet is built and its weights are loaded via
                ``self.load_model()`` (standalone validation).

        Raises:
            AssertionError: If ``options.train_set`` or ``options.task_to_val``
                is not one of the supported values.
        """
        if __name__ == "__main__":
            print(" -> Executing script", os.path.basename(__file__))

        self.opt = options
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LABELS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        assert self.opt.train_set in {1, 2, 3, 12, 123}, "Invalid train_set!"
        assert self.opt.task_to_val in {0, 1, 2, 3, 12, 123}, "Invalid task!"
        keys_to_load = ['color', 'segmentation']

        # Labels
        labels = self._get_labels_cityscapes()

        # Train IDs of the model, without the void ID 255
        model_train_ids = {label.trainId for label in labels}
        model_train_ids.remove(255)
        self.train_ids = sorted(model_train_ids)

        self.num_classes_model = len(self.train_ids)

        # Task handling: restrict the evaluation to the train IDs of one task
        if self.opt.task_to_val != 0:
            labels_task = self._get_task_labels_cityscapes()
            train_ids_task = {label.trainId for label in labels_task}
            train_ids_task.remove(255)
            self.task_low = min(train_ids_task)
            self.task_high = max(train_ids_task) + 1
            labels = labels_task
            self.train_ids = sorted(train_ids_task)
        else:
            self.task_low = 0
            self.task_high = self.num_classes_model
            self.opt.task_to_val = self.opt.train_set

        # Number of classes for the SegmentationRunningScore
        self.num_classes_score = self.task_high - self.task_low

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           DATASET DEFINITIONS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Data augmentation
        test_data_transforms = [
            mytransforms.CreateScaledImage(),
            mytransforms.Resize((self.opt.height, self.opt.width),
                                image_types=['color']),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True,
                                        scales=self.opt.scales),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        # If hyperparameter search, only load the respective validation set.
        # Else, load the full validation set.
        if self.opt.hyperparameter:
            trainvaltest_split = 'train'
            folders_to_load = CitySet.get_city_set(-1)
        else:
            trainvaltest_split = 'validation'
            folders_to_load = None

        test_dataset = CityscapesDataset(dataset='cityscapes',
                                         split=self.opt.dataset_split,
                                         trainvaltest_split=trainvaltest_split,
                                         video_mode='mono',
                                         stereo_mode='mono',
                                         scales=self.opt.scales,
                                         labels_mode='fromid',
                                         labels=labels,
                                         keys_to_load=keys_to_load,
                                         data_transforms=test_data_transforms,
                                         video_frames=self.opt.video_frames,
                                         folders_to_load=folders_to_load)

        self.test_loader = DataLoader(dataset=test_dataset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=False,
                                      num_workers=self.opt.num_workers,
                                      pin_memory=True,
                                      drop_last=False)

        print(
            "++++++++++++++++++++++ INIT VALIDATION ++++++++++++++++++++++++")
        print("Using dataset\n  ", self.opt.dataset, "with split",
              self.opt.dataset_split)
        print("There are {:d} validation items\n  ".format(len(test_dataset)))
        print("Validating classes up to train set\n  ", self.opt.train_set)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LOGGING OPTIONS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # If no model is passed, standalone validation is to be carried out.
        # The log_path needs to be set before self.load_model() is invoked.
        if model is None:
            self.opt.validate = False
            self.opt.model_name = self.opt.load_model_name

        path_getter = GetPath()
        log_path = path_getter.get_checkpoint_path()
        self.log_path = os.path.join(log_path, 'erfnet', self.opt.model_name)

        # All outputs will be saved to save_path
        self.save_path = self.log_path

        # Create output path for standalone validation
        if not self.opt.validate:
            save_dir = 'eval_{}'.format(self.opt.dataset)

            if self.opt.hyperparameter:
                save_dir = save_dir + '_hyper'

            save_dir = save_dir + '_task_to_val{}'.format(self.opt.task_to_val)

            self.save_path = os.path.join(self.log_path, save_dir)

            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(self.save_path, exist_ok=True)

        # Copy this file to save_path
        shutil.copy2(__file__, self.save_path)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           MODEL DEFINITION
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        if not self.opt.validate:
            # Standalone validation: create a conventional ERFNet and restore
            # its weights from disk.
            self.model = ERFNet(self.num_classes_model, self.opt)
            self.load_model()
            self.model.to(self.device)
        else:
            # Validate while training: use the model handed in by the caller.
            self.model = model

        self.model.eval()

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LOGGING OPTIONS II
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # self.called is used to decide which file mode shall be used when
        # writing metrics to disk.
        self.called = False

        self.metric_model = SegmentationRunningScore(self.num_classes_score)

        # Metrics are only saved if val_frequency != 0!
        if self.opt.val_frequency != 0:
            print("Saving metrics to\n  ", self.save_path)

        # Set up colour output. Coloured images are only output if standalone
        # validation is carried out!
        if not self.opt.validate and self.opt.save_pred_to_disk:
            # Output path
            self.img_path = os.path.join(
                self.save_path, 'output_{}'.format(self.opt.weights_epoch))

            if self.opt.pred_wout_blend:
                self.img_path += '_wout_blend'

            os.makedirs(self.img_path, exist_ok=True)
            print("Saving prediction images to\n  ", self.img_path)
            print("Save frequency\n  ", self.opt.pred_frequency)

            # Get the colours from dataset, keyed by the train ID shifted into
            # the range used by the score (i.e. relative to task_low).
            colors = [
                (label.trainId - self.task_low, label.color)
                for label in labels
                if label.trainId != 255 and label.trainId in self.train_ids
            ]
            colors.append((255, (0, 0, 0)))  # void class
            self.id_color = dict(colors)
            self.id_color_keys = list(self.id_color.keys())
            self.id_color_vals = list(self.id_color.values())

            # Ongoing index to name the outputs
            self.img_idx = 0

        # Set up probability output. Probabilities are only output if
        # standalone validation is carried out!
        if not self.opt.validate and self.opt.save_probs_to_disk:
            # Output path
            self.logit_path = os.path.join(
                self.save_path,
                'probabilities_{}'.format(self.opt.weights_epoch))
            os.makedirs(self.logit_path, exist_ok=True)
            print("Saving probabilities to\n  ", self.logit_path)
            print("Save frequency\n  ", self.opt.probs_frequency)

            # Ongoing index to name the probability outputs
            self.probs_idx = 0

        print(
            "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        # Save all options to disk and print them to stdout
        self._print_options()
        self._save_opts(len(test_dataset))
Beispiel #4
0
class Evaluator:
    def __init__(self, options, model=None):

        if __name__ == "__main__":
            print(" -> Executing script", os.path.basename(__file__))

        self.opt = options
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LABELS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        assert self.opt.train_set in {1, 2, 3, 12, 123}, "Invalid train_set!"
        assert self.opt.task_to_val in {0, 1, 2, 3, 12, 123}, "Invalid task!"
        keys_to_load = ['color', 'segmentation']

        # Labels
        labels = self._get_labels_cityscapes()

        # Train IDs
        self.train_ids = set([labels[i].trainId for i in range(len(labels))])
        self.train_ids.remove(255)
        self.train_ids = sorted(list(self.train_ids))

        self.num_classes_model = len(self.train_ids)

        # Task handling
        if self.opt.task_to_val != 0:
            labels_task = self._get_task_labels_cityscapes()
            train_ids_task = set(
                [labels_task[i].trainId for i in range(len(labels_task))])
            train_ids_task.remove(255)
            self.task_low = min(train_ids_task)
            self.task_high = max(train_ids_task) + 1
            labels = labels_task
            self.train_ids = sorted(list(train_ids_task))
        else:
            self.task_low = 0
            self.task_high = self.num_classes_model
            self.opt.task_to_val = self.opt.train_set

        # Number of classes for the SegmentationRunningScore
        self.num_classes_score = self.task_high - self.task_low

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           DATASET DEFINITIONS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Data augmentation
        test_data_transforms = [
            mytransforms.CreateScaledImage(),
            mytransforms.Resize((self.opt.height, self.opt.width),
                                image_types=['color']),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True,
                                        scales=self.opt.scales),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        # If hyperparameter search, only load the respective validation set. Else, load the full validation set.
        if self.opt.hyperparameter:
            trainvaltest_split = 'train'
            folders_to_load = CitySet.get_city_set(-1)
        else:
            trainvaltest_split = 'validation'
            folders_to_load = None

        test_dataset = CityscapesDataset(dataset='cityscapes',
                                         split=self.opt.dataset_split,
                                         trainvaltest_split=trainvaltest_split,
                                         video_mode='mono',
                                         stereo_mode='mono',
                                         scales=self.opt.scales,
                                         labels_mode='fromid',
                                         labels=labels,
                                         keys_to_load=keys_to_load,
                                         data_transforms=test_data_transforms,
                                         video_frames=self.opt.video_frames,
                                         folders_to_load=folders_to_load)

        self.test_loader = DataLoader(dataset=test_dataset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=False,
                                      num_workers=self.opt.num_workers,
                                      pin_memory=True,
                                      drop_last=False)

        print(
            "++++++++++++++++++++++ INIT VALIDATION ++++++++++++++++++++++++")
        print("Using dataset\n  ", self.opt.dataset, "with split",
              self.opt.dataset_split)
        print("There are {:d} validation items\n  ".format(len(test_dataset)))
        print("Validating classes up to train set\n  ", self.opt.train_set)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LOGGING OPTIONS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # If no model is passed, standalone validation is to be carried out. The log_path needs to be set before
        # self.load_model() is invoked.
        if model is None:
            self.opt.validate = False
            self.opt.model_name = self.opt.load_model_name

        path_getter = GetPath()
        log_path = path_getter.get_checkpoint_path()
        self.log_path = os.path.join(log_path, 'erfnet', self.opt.model_name)

        # All outputs will be saved to save_path
        self.save_path = self.log_path

        # Create output path for standalone validation
        if not self.opt.validate:
            save_dir = 'eval_{}'.format(self.opt.dataset)

            if self.opt.hyperparameter:
                save_dir = save_dir + '_hyper'

            save_dir = save_dir + '_task_to_val{}'.format(self.opt.task_to_val)

            self.save_path = os.path.join(self.log_path, save_dir)

            if not os.path.exists(self.save_path):
                os.makedirs(self.save_path)

        # Copy this file to save_path
        shutil.copy2(__file__, self.save_path)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           MODEL DEFINITION
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Standalone validation
        if not self.opt.validate:
            # Create a conventional ERFNet
            self.model = ERFNet(self.num_classes_model, self.opt)
            self.load_model()
            self.model.to(self.device)

        # Validate while training
        else:
            self.model = model

        self.model.eval()

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LOGGING OPTIONS II
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # self.called is used to decide which file mode shall be used when writing metrics to disk.
        self.called = False

        self.metric_model = SegmentationRunningScore(self.num_classes_score)

        # Metrics are only saved if val_frequency > 0!
        if self.opt.val_frequency != 0:
            print("Saving metrics to\n  ", self.save_path)

        # Set up colour output. Coloured images are only output if standalone validation is carried out!
        if not self.opt.validate and self.opt.save_pred_to_disk:
            # Output path
            self.img_path = os.path.join(
                self.save_path, 'output_{}'.format(self.opt.weights_epoch))

            if self.opt.pred_wout_blend:
                self.img_path += '_wout_blend'

            if not os.path.exists(self.img_path):
                os.makedirs(self.img_path)
            print("Saving prediction images to\n  ", self.img_path)
            print("Save frequency\n  ", self.opt.pred_frequency)

            # Get the colours from dataset.
            colors = [
                (label.trainId - self.task_low, label.color)
                for label in labels
                if label.trainId != 255 and label.trainId in self.train_ids
            ]
            colors.append((255, (0, 0, 0)))  # void class
            self.id_color = dict(colors)
            self.id_color_keys = [key for key in self.id_color.keys()]
            self.id_color_vals = [val for val in self.id_color.values()]

            # Ongoing index to name the outputs
            self.img_idx = 0

        # Set up probability output. Probabilities are only output if standalone validation is carried out!
        if not self.opt.validate and self.opt.save_probs_to_disk:
            # Output path
            self.logit_path = os.path.join(
                self.save_path,
                'probabilities_{}'.format(self.opt.weights_epoch))
            if not os.path.exists(self.logit_path):
                os.makedirs(self.logit_path)
            print("Saving probabilities to\n  ", self.logit_path)
            print("Save frequency\n  ", self.opt.probs_frequency)

            # Ongoing index to name the probability outputs
            self.probs_idx = 0

        print(
            "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        # Save all options to disk and print them to stdout
        self._print_options()
        self._save_opts(len(test_dataset))

    def _get_labels_cityscapes(self, id=None):
        if id is None:
            id = self.opt.train_set

        if id == 1:
            labels = labels_cityscape_seg_train1.getlabels()
        elif id == 2:
            labels = labels_cityscape_seg_train2_eval.getlabels()
        elif id == 12:
            labels = labels_cityscape_seg_train2_eval.getlabels()
        elif id == 3:
            labels = labels_cityscape_seg_train3_eval.getlabels()
        elif id == 123:
            labels = labels_cityscape_seg_train3_eval.getlabels()

        return labels

    def _get_task_labels_cityscapes(self, id=None):
        if id is None:
            id = self.opt.task_to_val

        if id == 1:
            labels_task = labels_cityscape_seg_train1.getlabels()
        elif id == 2:
            labels_task = labels_cityscape_seg_train2.getlabels()
        elif id == 12:
            labels_task = labels_cityscape_seg_train2_eval.getlabels()
        elif id == 3:
            labels_task = labels_cityscape_seg_train3.getlabels()
        elif id == 123:
            labels_task = labels_cityscape_seg_train3_eval.getlabels()

        return labels_task

    def load_model(self):
        """Load model(s) from disk
        """
        base_path = os.path.split(self.log_path)[0]
        checkpoint_path = os.path.join(
            base_path, self.opt.load_model_name, 'models',
            'weights_{}'.format(self.opt.weights_epoch))
        assert os.path.isdir(checkpoint_path), \
            "Cannot find folder {}".format(checkpoint_path)
        print("loading model from folder {}".format(checkpoint_path))

        path = os.path.join(checkpoint_path, "{}.pth".format('model'))
        model_dict = self.model.state_dict()
        if self.opt.no_cuda:
            pretrained_dict = torch.load(path, map_location='cpu')
        else:
            pretrained_dict = torch.load(path)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    def calculate_metrics(self, epoch=None):
        print("-> Computing predictions with input size {}x{}".format(
            self.opt.height, self.opt.width))
        print("-> Evaluating")

        for data in self.test_loader:
            with torch.no_grad():
                input_color = data[("color_aug", 0, 0)]
                gt_seg = data[('segmentation', 0, 0)][:, 0, :, :].numpy()
                input_color = {
                    ("color_aug", 0, 0): input_color.to(self.device)
                }

                output = self.model(input_color)

                pred_seg = output['segmentation_logits'].float()

                # Apply task reduction for argmax
                if self.opt.task_to_val != 0:
                    pred_seg = pred_seg[:, self.task_low:self.task_high, ...]
                    gt_seg -= self.task_low  # gt_seg trainIDs must be in range(0, self.num_classes_score) to map them with torch.argmax output
                    gt_seg[
                        gt_seg == 255 - self.
                        task_low] = 255  # maintaining the background trainID

                # Save probabilities to disk
                if not self.opt.validate and self.opt.save_probs_to_disk:
                    self._save_probs_to_disk(
                        F.softmax(pred_seg, dim=1).cpu().numpy())

                pred_seg = F.interpolate(pred_seg,
                                         gt_seg[0].shape,
                                         mode='nearest')

                # Select most probable class
                pred_seg = torch.argmax(pred_seg, dim=1)

                pred_seg = pred_seg.cpu().numpy()
                self.metric_model.update(gt_seg, pred_seg)

                # Save predictions to disk
                if not self.opt.validate and self.opt.save_pred_to_disk:
                    self._save_pred_to_disk(pred_seg, gt_seg)

        metrics = self.metric_model.get_scores()

        # Save metrics
        if self.opt.val_frequency != 0:
            # Local epoch will not be specified if the validation is carried out standalone.
            if not self.opt.validate and epoch is None:
                epoch = int(self.opt.weights_epoch)

            self._save_metrics(epoch, metrics)

        self.metric_model.reset()
        print("\n  " + ("{:>8} | " * 2).format("miou", "maccuracy"))
        print(("&{: 8.3f}  " *
               2).format(metrics['meaniou'], metrics['meanacc']) + "\\\\")
        print("\n-> Done!")

    def _save_metrics(self, epoch, metrics):
        ''' Save metrics (class-wise) to disk as HDF5 file.
        '''
        # If a single model is validated, the output file will carry its epoch number in its file name. If a learning
        # process is validated "on the go", the output filename will just be "validation.h5".
        if not self.opt.validate:
            filename = 'validation_{:d}.h5'.format(epoch)
        else:
            filename = 'validation.h5'
        save_path = os.path.join(self.save_path, filename)

        # When _save_metrics is invoked for the first time, the HDF file will be opened in "w" mode overwriting any
        # existing file. In case of another invocation, the file will be opened in "a" mode not overwriting any
        # existing file but appending the data.
        if not self.called:
            mode = 'w'
            self.called = True
        else:
            mode = 'a'

        # If a single model is validated, all datasets reside in the first layer of the HDF file. If a learning process
        # is validated "on the go", each validated model will have its own group named after the epoch of the model.
        with h5.File(save_path, mode) as f:
            if self.opt.validate:
                grp = f.create_group('epoch_{:d}'.format(epoch))
            else:
                grp = f

            # Write mean_IoU, mean_acc and mean prec to file / group
            dset = grp.create_dataset('mean_IoU', data=metrics['meaniou'])
            dset.attrs[
                'Description'] = 'See trainIDs for information on the classes'
            dset = grp.create_dataset('mean_recall', data=metrics['meanacc'])
            dset.attrs[
                'Description'] = 'See trainIDs for information on the classes'
            dset.attrs['AKA'] = 'Accuracy -> TP / (TP + FN)'
            dset = grp.create_dataset('mean_precision',
                                      data=metrics['meanprec'])
            dset.attrs[
                'Description'] = 'See trainIDs for information on the classes'
            dset.attrs['AKA'] = 'Precision -> TP / (TP + FP)'

            # If in 'w' mode, allocate memory for class_id dataset
            if mode == 'w':
                ids = np.zeros(shape=(len(metrics['iou'])), dtype=np.uint32)

            class_iou = np.zeros(shape=(len(metrics['iou'])), dtype=np.float64)
            class_acc = np.zeros(shape=(len(metrics['acc'])), dtype=np.float64)
            class_prec = np.zeros(shape=(len(metrics['prec'])),
                                  dtype=np.float64)

            # Disassemble the dictionary
            for key, i in zip(sorted(metrics['iou']),
                              range(len(metrics['iou']))):
                if mode == 'w':
                    ids[i] = self.train_ids[i]  # int(key)
                class_iou[i] = metrics['iou'][key]
                class_acc[i] = metrics['acc'][key]
                class_prec[i] = metrics['prec'][key]

            # Create class_id dataset only once in first layer of HDF5 file when in 'w' mode
            if mode == 'w':
                dset = f.create_dataset('trainIDs', data=ids)
                dset.attrs['Description'] = 'trainIDs of classes'
                dset = f.create_dataset('first_epoch_in_file',
                                        data=np.array([epoch
                                                       ]).astype(np.uint32))
                dset.attrs[
                    'Description'] = 'First epoch that has been saved in this file.'

            dset = grp.create_dataset('class_IoU', data=class_iou)
            dset.attrs[
                'Description'] = 'See trainIDs for information on the class order'
            dset = grp.create_dataset('class_recall', data=class_acc)
            dset.attrs[
                'Description'] = 'See trainIDs for information on the class order'
            dset.attrs['AKA'] = 'Accuracy -> TP / (TP + FN)'
            dset = grp.create_dataset('class_precision', data=class_prec)
            dset.attrs[
                'Description'] = 'See trainIDs for information on the class order'
            dset.attrs['AKA'] = 'Precision -> TP / (TP + FP)'

    def _save_pred_to_disk(self, pred, gt):
        ''' Save a correctly coloured image of the prediction (batch) to disk. Only every self.opt.pred_frequency-th
            prediction is saved to disk!
        '''
        for i in range(gt.shape[0]):
            if self.img_idx % self.opt.pred_frequency == 0:
                o_size = gt[i].shape  # original image shape

                single_pred = pred[i].flatten()
                single_gt = gt[i].flatten()

                # Copy voids from ground truth to prediction
                if not self.opt.pred_wout_blend:
                    single_pred[single_gt == 255] = 255

                # Convert to colour
                single_pred = self._convert_to_colour(single_pred, o_size)
                single_gt = self._convert_to_colour(single_gt, o_size)

                # Save predictions to disk using an ongoing index
                cv2.imwrite(
                    os.path.join(self.img_path,
                                 'pred_val_{}.png'.format(self.img_idx)),
                    single_pred)
                cv2.imwrite(
                    os.path.join(self.img_path,
                                 'gt_val_{}.png'.format(self.img_idx)),
                    single_gt)

            self.img_idx += 1

    def _convert_to_colour(self, img, o_size):
        ''' Replace trainIDs in prediction with colours from dict, reshape it afterwards to input dimensions and
            convert RGB to BGR to match openCV's colour system.

            img:    flat array of IDs; every value is expected to be contained in self.id_color_keys
                    (IDs missing from the table are not detected here)
            o_size: shape of the original image, only the first two entries are used
            Returns a uint8 array of shape (o_size[0], o_size[1], 3).
        '''
        # Look each ID up in the sorted key list and pick the colour stored for that key
        sort_idx = np.argsort(self.id_color_keys)
        idx = np.searchsorted(self.id_color_keys, img, sorter=sort_idx)
        img = np.asarray(self.id_color_vals)[sort_idx][idx]
        img = img.astype(np.uint8)
        # The shape is passed positionally: the 'newshape' keyword is deprecated in recent NumPy releases
        img = np.reshape(img, (o_size[0], o_size[1], 3))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        return img

    def _save_probs_to_disk(self, output):
        ''' Store network outputs of a batch with np.save in self.logit_path.

            self.probs_idx is increased for every sample of the batch; a file is only written for samples
            whose running index is a multiple of self.opt.probs_frequency.
        '''
        for sample in range(output.shape[0]):
            is_due = self.probs_idx % self.opt.probs_frequency == 0
            if is_due:
                # The running index is part of the file name
                target = os.path.join(self.logit_path, 'seg_logit_{}'.format(self.probs_idx))
                np.save(target, output[sample])

            self.probs_idx += 1

    def _print_options(self):
        ''' Write all options to stdout, sorted by option name and aligned in two columns
        '''
        options = vars(self.opt)

        # Width of the name column is given by the longest option name
        width = max(len(name) for name in options)

        print(
            "+++++++++++++++++++++++++++ OPTIONS +++++++++++++++++++++++++++")
        for name, value in sorted(options.items()):
            print(name.ljust(width), value)
        print(
            "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

    def _save_opts(self, n_eval):
        """Dump the options together with n_eval as JSON into self.save_path

        The file is called 'eval_opt.json' when self.opt.validate is set, otherwise the value of
        self.opt.weights_epoch is part of the file name.
        """
        if self.opt.validate:
            filename = 'eval_opt.json'
        else:
            filename = 'eval_opt_{}.json'.format(self.opt.weights_epoch)

        # Work on a copy so that the options object itself does not get an 'n_eval' entry
        to_save = dict(self.opt.__dict__, n_eval=n_eval)

        target = os.path.join(self.save_path, filename)
        with open(target, 'w') as f:
            json.dump(to_save, f, indent=2)
Beispiel #5
0
    def __init__(self, options):
        """Prepare the demo: check the options, load the model and derive all output names.

        The attributes of options read here are no_cuda, image, model_stage, task, with_weights,
        ground_truth and output_path; options is also handed to ERFNet. If output_path is empty it is
        replaced by the directory of the input image.

        Raises an AssertionError if the image or the ground truth is not a file or if the task is not
        available for the chosen model stage.
        """
        self.opt = options
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # Assertions
        assert os.path.isfile(self.opt.image), "Invalid image!"
        # os.path splits on every separator the platform accepts, no manual separator replacement needed
        self.image_name = os.path.basename(self.opt.image)

        if self.opt.model_stage == 1:
            assert self.opt.task in {1}, "Invalid task!"
            assert not self.opt.with_weights, "Weights for stage 1 not available"
        elif self.opt.model_stage == 2:
            assert self.opt.task in {1, 2, 12}, "Invalid task!"
        elif self.opt.model_stage == 3:
            assert self.opt.task in {1, 2, 3, 12, 123}, "Invalid task!"

        # Model and task set-up: number of model outputs per stage, [low, high) class range per task
        self.num_classes_model = {1: 5, 2: 11, 3: 19}[self.opt.model_stage]
        self.task_low, self.task_high = {1: (0, 5), 2: (5, 11), 3: (11, 19), 12: (0, 11), 123: (0, 19)}[self.opt.task]

        # Create a conventional ERFNet
        self.model = ERFNet(self.num_classes_model, self.opt)
        self._load_model()
        self.model.to(self.device)
        self.model.eval()

        # Ground truth: metrics are only computed if a ground truth file is given
        self.metrics = False
        if self.opt.ground_truth:
            assert os.path.isfile(self.opt.ground_truth), "Invalid ground truth!"
            self.metrics = True
            self.num_classes_score = self.task_high - self.task_low
            self.metric_model = SegmentationRunningScore(self.num_classes_score)

        # Output directory: create the requested one or fall back to the directory of the image
        if self.opt.output_path:
            if not os.path.isdir(self.opt.output_path):
                os.makedirs(self.opt.output_path)
        else:
            self.opt.output_path = os.path.dirname(self.opt.image) or os.curdir

        # Output names: <image name>_<seg|gt>_stage_<stage>_task_<task><image extension>
        name_stem, name_extension = os.path.splitext(self.image_name)
        segmentation_name = name_stem + \
                            "_seg_stage_{}_task_{}".format(self.opt.model_stage, self.opt.task) + \
                            name_extension
        self.output_image = os.path.join(self.opt.output_path, segmentation_name)
        ground_truth_name = name_stem + \
                            "_gt_stage_{}_task_{}".format(self.opt.model_stage, self.opt.task) + \
                            name_extension
        self.output_gt = os.path.join(self.opt.output_path, ground_truth_name)

        # stdout output
        print("++++++++++++++++++++++ INIT DEMO ++++++++++++++++++++++++")
        print("Image:\t {}".format(self.opt.image))
        print("GT:\t {}".format(self.opt.ground_truth))
        print("Output:\t {}".format(self.opt.output_path))
        print("Stage:\t {}".format(self.opt.model_stage))
        print("Weights: {}".format(self.opt.with_weights))
        print("Task:\t {}".format(self.opt.task))
        print("!!! MIND THAT THE MODELS WERE TRAINED USING AN IMAGE RESOLUTION OF 1024x512px !!!")

        # Class colours: IDs are shifted by task_low so that they match the sliced network output
        labels = labels_cityscape_seg_train3_eval.getlabels()
        colors = [(label.trainId - self.task_low, label.color) for label in labels
                  if label.trainId in range(0, 19)]
        colors.append((255, (0, 0, 0)))  # void class
        self.id_color = dict(colors)
        self.id_color_keys = list(self.id_color.keys())
        self.id_color_vals = list(self.id_color.values())
Beispiel #6
0
class Demo(object):
    def __init__(self, options):
        """Prepare the demo: check the options, load the model and derive all output names.

        The attributes of options read here are no_cuda, image, model_stage, task, with_weights,
        ground_truth and output_path; options is also handed to ERFNet. If output_path is empty it is
        replaced by the directory of the input image.

        Raises an AssertionError if the image or the ground truth is not a file or if the task is not
        available for the chosen model stage.
        """
        self.opt = options
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # Assertions
        assert os.path.isfile(self.opt.image), "Invalid image!"
        # os.path splits on every separator the platform accepts, no manual separator replacement needed
        self.image_name = os.path.basename(self.opt.image)

        if self.opt.model_stage == 1:
            assert self.opt.task in {1}, "Invalid task!"
            assert not self.opt.with_weights, "Weights for stage 1 not available"
        elif self.opt.model_stage == 2:
            assert self.opt.task in {1, 2, 12}, "Invalid task!"
        elif self.opt.model_stage == 3:
            assert self.opt.task in {1, 2, 3, 12, 123}, "Invalid task!"

        # Model and task set-up: number of model outputs per stage, [low, high) class range per task
        self.num_classes_model = {1: 5, 2: 11, 3: 19}[self.opt.model_stage]
        self.task_low, self.task_high = {1: (0, 5), 2: (5, 11), 3: (11, 19), 12: (0, 11), 123: (0, 19)}[self.opt.task]

        # Create a conventional ERFNet
        self.model = ERFNet(self.num_classes_model, self.opt)
        self._load_model()
        self.model.to(self.device)
        self.model.eval()

        # Ground truth: metrics are only computed if a ground truth file is given
        self.metrics = False
        if self.opt.ground_truth:
            assert os.path.isfile(self.opt.ground_truth), "Invalid ground truth!"
            self.metrics = True
            self.num_classes_score = self.task_high - self.task_low
            self.metric_model = SegmentationRunningScore(self.num_classes_score)

        # Output directory: create the requested one or fall back to the directory of the image
        if self.opt.output_path:
            if not os.path.isdir(self.opt.output_path):
                os.makedirs(self.opt.output_path)
        else:
            self.opt.output_path = os.path.dirname(self.opt.image) or os.curdir

        # Output names: <image name>_<seg|gt>_stage_<stage>_task_<task><image extension>
        name_stem, name_extension = os.path.splitext(self.image_name)
        segmentation_name = name_stem + \
                            "_seg_stage_{}_task_{}".format(self.opt.model_stage, self.opt.task) + \
                            name_extension
        self.output_image = os.path.join(self.opt.output_path, segmentation_name)
        ground_truth_name = name_stem + \
                            "_gt_stage_{}_task_{}".format(self.opt.model_stage, self.opt.task) + \
                            name_extension
        self.output_gt = os.path.join(self.opt.output_path, ground_truth_name)

        # stdout output
        print("++++++++++++++++++++++ INIT DEMO ++++++++++++++++++++++++")
        print("Image:\t {}".format(self.opt.image))
        print("GT:\t {}".format(self.opt.ground_truth))
        print("Output:\t {}".format(self.opt.output_path))
        print("Stage:\t {}".format(self.opt.model_stage))
        print("Weights: {}".format(self.opt.with_weights))
        print("Task:\t {}".format(self.opt.task))
        print("!!! MIND THAT THE MODELS WERE TRAINED USING AN IMAGE RESOLUTION OF 1024x512px !!!")

        # Class colours: IDs are shifted by task_low so that they match the sliced network output
        labels = labels_cityscape_seg_train3_eval.getlabels()
        colors = [(label.trainId - self.task_low, label.color) for label in labels
                  if label.trainId in range(0, 19)]
        colors.append((255, (0, 0, 0)))  # void class
        self.id_color = dict(colors)
        self.id_color_keys = list(self.id_color.keys())
        self.id_color_vals = list(self.id_color.values())


    def _load_model(self):
        """Load model from disk
        """
        path = self.opt.checkpoint_path
        # checkpoint_path = os.path.join("models", "stage_{}".format(self.opt.model_stage))
        #assert os.path.isdir(checkpoint_path), \
        #    "Cannot find folder {}".format(checkpoint_path)

        # path = os.path.join(checkpoint_path, "{}.pth".format("with_weights" if self.opt.with_weights else "wout_weights"))
        model_dict = self.model.state_dict()
        if self.opt.no_cuda:
            pretrained_dict = torch.load(path, map_location='cpu')
        else:
            pretrained_dict = torch.load(path)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    def process_image(self):
        """Segment the configured image, optionally score it against the ground truth and save the result.

        The image is resized to 512 x 1024 (height x width) and normalised before it is passed to the model.
        The logits of the selected task are resized back to the native image resolution before the per-pixel
        argmax is taken. If a ground truth is given, all of its labels outside of [task_low, task_high) are
        mapped to the void ID 255, the remaining IDs are shifted to start at 0 and the metrics are updated,
        saved and printed.

        Raises an IOError if OpenCV cannot read the image or the ground truth.
        """
        # Required image transformations
        resize_interp = transforms.Resize((512, 1024), interpolation=pil.BILINEAR)
        transformer = transforms.ToTensor()
        normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

        # Load Image; cv2.imread does not raise on failure, it returns None
        image = cv2.imread(self.opt.image)
        if image is None:
            raise IOError("Could not read image {}".format(self.opt.image))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = pil.fromarray(image)
        native_image_size = image.size  # used as (width, height) below

        # Transform image
        image = resize_interp(image)
        image = transformer(image)
        image = normalize(image).unsqueeze(0).to(self.device)

        # Process image; pure inference, so no autograd graph has to be recorded
        input_rgb = {("color_aug", 0, 0): image}
        with torch.no_grad():
            output = self.model(input_rgb)

            # Process network output: keep only the channels of the selected task
            pred_seg = output['segmentation_logits'].float()
            pred_seg = pred_seg[:, self.task_low:self.task_high, ...]
            pred_seg = F.interpolate(pred_seg, (native_image_size[1], native_image_size[0]), mode='nearest')
            pred_seg = torch.argmax(pred_seg, dim=1)
        pred_seg = pred_seg.cpu().numpy()

        # Process ground truth
        gt = None
        if self.opt.ground_truth:
            gt = cv2.imread(self.opt.ground_truth, 0)
            if gt is None:
                raise IOError("Could not read ground truth {}".format(self.opt.ground_truth))
            # Everything outside of the task's ID range becomes void ...
            gt[gt < self.task_low] = 255
            gt[gt >= self.task_high] = 255
            # ... then the IDs are shifted to start at 0 and the (shifted) void ID is restored to 255
            gt -= self.task_low
            gt[gt == 255 - self.task_low] = 255
            gt = np.expand_dims(gt, 0)

            self.metric_model.update(gt, pred_seg)
            metrics = self.metric_model.get_scores()
            self._save_metrics(metrics)
            print("\n  " + ("{:>8} | " * 2).format("miou", "maccuracy"))
            print(("&{: 8.3f}  " * 2).format(metrics['meaniou'], metrics['meanacc']) + "\\\\")

        # Save prediction to disk
        self._save_pred_to_disk(pred_seg, gt)

        print("\n-> Done!")


    def _save_metrics(self, metrics):
        ''' Save metrics (mean and class-wise) to disk as HDF5 file "demo.h5" in self.opt.output_path.

            metrics: dictionary with the entries 'meaniou', 'meanacc', 'meanprec' and the per-class
                     dictionaries 'iou', 'acc' and 'prec', which share the same keys
        '''
        save_path = os.path.join(self.opt.output_path, "demo.h5")

        with h5.File(save_path, 'w') as f:
            # Write mean_IoU, mean_acc and mean prec to file
            dset = f.create_dataset('mean_IoU', data=metrics['meaniou'])
            dset.attrs['Description'] = 'See trainIDs for information on the classes'
            dset = f.create_dataset('mean_recall', data=metrics['meanacc'])
            dset.attrs['Description'] = 'See trainIDs for information on the classes'
            dset.attrs['AKA'] = 'Accuracy -> TP / (TP + FN)'
            dset = f.create_dataset('mean_precision', data=metrics['meanprec'])
            dset.attrs['Description'] = 'See trainIDs for information on the classes'
            dset.attrs['AKA'] = 'Precision -> TP / (TP + FP)'

            num_classes = len(metrics['iou'])
            ids = np.zeros(shape=num_classes, dtype=np.uint32)

            class_iou = np.zeros(shape=num_classes, dtype=np.float64)
            class_acc = np.zeros(shape=len(metrics['acc']), dtype=np.float64)
            class_prec = np.zeros(shape=len(metrics['prec']), dtype=np.float64)

            # Disassemble the dictionary; the sorted keys define the class order of all class-wise datasets
            for i, key in enumerate(sorted(metrics['iou'])):
                # Record which class sits at position i (the array used to stay all zeros).
                # NOTE(review): assumes the keys are integer class indices as used by the metric - confirm
                ids[i] = key
                class_iou[i] = metrics['iou'][key]
                class_acc[i] = metrics['acc'][key]
                class_prec[i] = metrics['prec'][key]

            dset = f.create_dataset('trainIDs', data=ids)
            dset.attrs['Description'] = 'trainIDs of classes'

            dset = f.create_dataset('class_IoU', data=class_iou)
            dset.attrs['Description'] = 'See trainIDs for information on the class order'
            dset = f.create_dataset('class_recall', data=class_acc)
            dset.attrs['Description'] = 'See trainIDs for information on the class order'
            dset.attrs['AKA'] = 'Accuracy -> TP / (TP + FN)'
            dset = f.create_dataset('class_precision', data=class_prec)
            dset.attrs['Description'] = 'See trainIDs for information on the class order'
            dset.attrs['AKA'] = 'Precision -> TP / (TP + FP)'

    def _save_pred_to_disk(self, pred, gt=None):
        ''' Colour the first prediction of the batch and write it to self.output_image. If a ground truth is
            given, its void pixels (255) are copied to the prediction and the coloured ground truth is written
            to self.output_gt.
        '''
        first_pred = pred[0]
        orig_shape = first_pred.shape
        flat_pred = first_pred.flatten()

        if gt is not None:
            flat_gt = gt[0].flatten()
            flat_pred[flat_gt == 255] = 255
            colour_gt = self._convert_to_colour(flat_gt, orig_shape)
            cv2.imwrite(self.output_gt, colour_gt)

        colour_pred = self._convert_to_colour(flat_pred, orig_shape)
        cv2.imwrite(self.output_image, colour_pred)


    def _convert_to_colour(self, img, o_size):
        ''' Replace trainIDs in prediction with colours from dict, reshape it afterwards to input dimensions and
            convert RGB to BGR to match openCV's colour system.

            img:    flat array of IDs; every value is expected to be contained in self.id_color_keys
                    (IDs missing from the table are not detected here)
            o_size: shape of the original image, only the first two entries are used
            Returns a uint8 array of shape (o_size[0], o_size[1], 3).
        '''
        # Look each ID up in the sorted key list and pick the colour stored for that key
        sort_idx = np.argsort(self.id_color_keys)
        idx = np.searchsorted(self.id_color_keys, img, sorter=sort_idx)
        img = np.asarray(self.id_color_vals)[sort_idx][idx]
        img = img.astype(np.uint8)
        # The shape is passed positionally: the 'newshape' keyword is deprecated in recent NumPy releases
        img = np.reshape(img, (o_size[0], o_size[1], 3))
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        return img
    def __init__(self, options):
        """Assemble the complete training set-up from the given options.

        Steps, in this order: select labels and city folders, create the training and validation datasets
        with their loaders, open the tensorboard writers, instantiate the ERFNet model, the Adam optimizer
        with its learning-rate schedule, the loss and the running metric, save and print the options and -
        if options.validate is set - create the Evaluator.
        """
        print(" -> Executing script", os.path.basename(__file__))

        self.opt = options
        opt = self.opt
        self.device = torch.device("cpu" if opt.no_cuda else "cuda")

        # ------------------------------ labels and cities ------------------------------
        assert opt.train_set in {123, 1}, "Invalid train_set!"
        load_keys = ['color', 'segmentation']

        label_source = (labels_cityscape_seg_train1 if opt.train_set == 1
                        else labels_cityscape_seg_train3_eval)
        labels = label_source.getlabels()

        # All train IDs occurring in the labels, without the ID 255
        self.train_ids = {label.trainId for label in labels}
        self.train_ids.remove(255)
        self.num_classes = len(self.train_ids)

        # City set 0 is used unless the city filter is switched on
        train_folders = CitySet.get_city_set(0)
        if opt.city:
            train_folders = CitySet.get_city_set(opt.train_set)

        # ------------------------------ datasets and loaders ------------------------------
        image_size = (opt.height, opt.width)
        crop_size = (opt.crop_height, opt.crop_width)

        # Training pipeline including the data augmentation
        train_transforms = [
            mytransforms.RandomHorizontalFlip(),
            mytransforms.CreateScaledImage(),
            mytransforms.Resize(image_size, image_types=load_keys),
            mytransforms.RandomRescale(1.5),
            mytransforms.RandomCrop(crop_size),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True, scales=opt.scales),
            mytransforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        # Arguments shared by the training and the validation dataset / loader
        shared_dataset_args = dict(video_mode='mono',
                                   stereo_mode='mono',
                                   scales=opt.scales,
                                   labels_mode='fromid',
                                   labels=labels,
                                   keys_to_load=load_keys,
                                   video_frames=opt.video_frames)
        shared_loader_args = dict(batch_size=opt.batch_size,
                                  num_workers=opt.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

        train_dataset = CityscapesDataset(dataset="cityscapes",
                                          trainvaltest_split='train',
                                          data_transforms=train_transforms,
                                          folders_to_load=train_folders,
                                          **shared_dataset_args)
        self.train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_loader_args)

        # Validation pipeline: same as training, but without the random augmentation steps
        val_transforms = [
            mytransforms.CreateScaledImage(),
            mytransforms.Resize(image_size, image_types=load_keys),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True, scales=opt.scales),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        val_dataset = CityscapesDataset(dataset=opt.dataset,
                                        trainvaltest_split="train",
                                        data_transforms=val_transforms,
                                        folders_to_load=CitySet.get_city_set(-1),
                                        **shared_dataset_args)
        self.val_loader = DataLoader(dataset=val_dataset, shuffle=False, **shared_loader_args)

        # Iterator from which val() draws single minibatches
        self.val_iter = iter(self.val_loader)

        # ------------------------------ logging ------------------------------
        print("++++++++++++++++++++++ INIT TRAINING ++++++++++++++++++++++++++")
        print("Using dataset:\n  ", opt.dataset, "with split", opt.dataset_split)
        print("There are {:d} training items and {:d} validation items\n".format(
            len(train_dataset), len(val_dataset)))

        checkpoint_root = GetPath().get_checkpoint_path()
        self.log_path = os.path.join(checkpoint_root, 'erfnet', opt.model_name)

        # One tensorboard writer per mode, each with its own sub-directory of the log path
        self.writers = {}
        for mode in ("train", "validation"):
            self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))

        # Keep a copy of this script next to the logs
        shutil.copy2(__file__, self.log_path)

        print("Training model named:\n  ", opt.model_name)
        print("Models and tensorboard events files are saved to:\n  ", self.log_path)
        print("Training is using:\n  ", self.device)
        print("Training takes place on train set:\n  ", opt.train_set)
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        # ------------------------------ model ------------------------------
        self.model = ERFNet(self.num_classes, opt)
        self.model.to(self.device)
        self.parameters_to_train = self.model.parameters()

        # ------------------------------ optimizer ------------------------------
        self.model_optimizer = optim.Adam(params=self.parameters_to_train,
                                          lr=opt.learning_rate,
                                          weight_decay=opt.weight_decay)

        def lr_factor(epoch):
            # Factor applied to the initial learning rate in the given epoch
            return pow((1 - ((epoch - 1) / self.opt.num_epochs)), 0.9)

        self.model_lr_scheduler = optim.lr_scheduler.LambdaLR(self.model_optimizer, lr_lambda=lr_factor)

        # ------------------------------ losses and metric ------------------------------
        self.crossentropy = CrossEntropyLoss(ignore_background=True, device=self.device)
        self.crossentropy.to(self.device)

        self.metric_model = SegmentationRunningScore(self.num_classes)

        # Save all options to disk and print them to stdout
        self.save_opts(len(train_dataset), len(val_dataset))
        self._print_options()

        # ------------------------------ evaluator ------------------------------
        if opt.validate:
            self.evaluator = Evaluator(opt, self.model)
class Trainer:
    def __init__(self, options):
        """Assemble the complete training set-up from the given options.

        Steps, in this order: select labels and city folders, create the training and validation datasets
        with their loaders, open the tensorboard writers, instantiate the ERFNet model, the Adam optimizer
        with its learning-rate schedule, the loss and the running metric, save and print the options and -
        if options.validate is set - create the Evaluator.
        """
        print(" -> Executing script", os.path.basename(__file__))

        self.opt = options
        opt = self.opt
        self.device = torch.device("cpu" if opt.no_cuda else "cuda")

        # ------------------------------ labels and cities ------------------------------
        assert opt.train_set in {123, 1}, "Invalid train_set!"
        load_keys = ['color', 'segmentation']

        label_source = (labels_cityscape_seg_train1 if opt.train_set == 1
                        else labels_cityscape_seg_train3_eval)
        labels = label_source.getlabels()

        # All train IDs occurring in the labels, without the ID 255
        self.train_ids = {label.trainId for label in labels}
        self.train_ids.remove(255)
        self.num_classes = len(self.train_ids)

        # City set 0 is used unless the city filter is switched on
        train_folders = CitySet.get_city_set(0)
        if opt.city:
            train_folders = CitySet.get_city_set(opt.train_set)

        # ------------------------------ datasets and loaders ------------------------------
        image_size = (opt.height, opt.width)
        crop_size = (opt.crop_height, opt.crop_width)

        # Training pipeline including the data augmentation
        train_transforms = [
            mytransforms.RandomHorizontalFlip(),
            mytransforms.CreateScaledImage(),
            mytransforms.Resize(image_size, image_types=load_keys),
            mytransforms.RandomRescale(1.5),
            mytransforms.RandomCrop(crop_size),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True, scales=opt.scales),
            mytransforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, gamma=0.0),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        # Arguments shared by the training and the validation dataset / loader
        shared_dataset_args = dict(video_mode='mono',
                                   stereo_mode='mono',
                                   scales=opt.scales,
                                   labels_mode='fromid',
                                   labels=labels,
                                   keys_to_load=load_keys,
                                   video_frames=opt.video_frames)
        shared_loader_args = dict(batch_size=opt.batch_size,
                                  num_workers=opt.num_workers,
                                  pin_memory=True,
                                  drop_last=True)

        train_dataset = CityscapesDataset(dataset="cityscapes",
                                          trainvaltest_split='train',
                                          data_transforms=train_transforms,
                                          folders_to_load=train_folders,
                                          **shared_dataset_args)
        self.train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_loader_args)

        # Validation pipeline: same as training, but without the random augmentation steps
        val_transforms = [
            mytransforms.CreateScaledImage(),
            mytransforms.Resize(image_size, image_types=load_keys),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True, scales=opt.scales),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        val_dataset = CityscapesDataset(dataset=opt.dataset,
                                        trainvaltest_split="train",
                                        data_transforms=val_transforms,
                                        folders_to_load=CitySet.get_city_set(-1),
                                        **shared_dataset_args)
        self.val_loader = DataLoader(dataset=val_dataset, shuffle=False, **shared_loader_args)

        # Iterator from which val() draws single minibatches
        self.val_iter = iter(self.val_loader)

        # ------------------------------ logging ------------------------------
        print("++++++++++++++++++++++ INIT TRAINING ++++++++++++++++++++++++++")
        print("Using dataset:\n  ", opt.dataset, "with split", opt.dataset_split)
        print("There are {:d} training items and {:d} validation items\n".format(
            len(train_dataset), len(val_dataset)))

        checkpoint_root = GetPath().get_checkpoint_path()
        self.log_path = os.path.join(checkpoint_root, 'erfnet', opt.model_name)

        # One tensorboard writer per mode, each with its own sub-directory of the log path
        self.writers = {}
        for mode in ("train", "validation"):
            self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))

        # Keep a copy of this script next to the logs
        shutil.copy2(__file__, self.log_path)

        print("Training model named:\n  ", opt.model_name)
        print("Models and tensorboard events files are saved to:\n  ", self.log_path)
        print("Training is using:\n  ", self.device)
        print("Training takes place on train set:\n  ", opt.train_set)
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        # ------------------------------ model ------------------------------
        self.model = ERFNet(self.num_classes, opt)
        self.model.to(self.device)
        self.parameters_to_train = self.model.parameters()

        # ------------------------------ optimizer ------------------------------
        self.model_optimizer = optim.Adam(params=self.parameters_to_train,
                                          lr=opt.learning_rate,
                                          weight_decay=opt.weight_decay)

        def lr_factor(epoch):
            # Factor applied to the initial learning rate in the given epoch
            return pow((1 - ((epoch - 1) / self.opt.num_epochs)), 0.9)

        self.model_lr_scheduler = optim.lr_scheduler.LambdaLR(self.model_optimizer, lr_lambda=lr_factor)

        # ------------------------------ losses and metric ------------------------------
        self.crossentropy = CrossEntropyLoss(ignore_background=True, device=self.device)
        self.crossentropy.to(self.device)

        self.metric_model = SegmentationRunningScore(self.num_classes)

        # Save all options to disk and print them to stdout
        self.save_opts(len(train_dataset), len(val_dataset))
        self._print_options()

        # ------------------------------ evaluator ------------------------------
        if opt.validate:
            self.evaluator = Evaluator(opt, self.model)

    def set_train(self):
        """Switch the model to training mode by calling train() on self.model
        """
        self.model.train()

    def set_eval(self):
        """Switch the model to testing/evaluation mode by calling eval() on self.model
        """
        self.model.eval()

    def train(self):
        """Run the entire training pipeline
        """
        self.epoch = 0
        self.step = 0
        self.start_time = time.time()

        for self.epoch in range(self.opt.num_epochs):
            self.run_epoch()
            if (self.epoch + 1) % self.opt.save_frequency == 0:
                self.save_model()
            if self.opt.validate and (self.epoch +
                                      1) % self.opt.val_frequency == 0:
                self.run_eval()

    def run_epoch(self):
        """Train for one pass over ``self.train_loader``.

        Every batch is pushed through ``process_batch`` and followed by one
        optimizer update. On logging steps the training scalars are written,
        a timing line is printed and a single validation minibatch is
        evaluated through ``val``. The learning-rate scheduler is stepped
        once when the loader is exhausted.
        """
        print("Training")
        self.set_train()

        seg_key = ('segmentation', 0, 0)
        for batch_idx, inputs in enumerate(self.train_loader):
            tic = time.time()

            outputs, losses = self.process_batch(inputs)

            self.model_optimizer.zero_grad()
            losses["loss"].backward()
            self.model_optimizer.step()
            duration = time.time() - tic

            # Log on the configured batch frequency during the first 2000
            # steps; afterwards only every 2000th step, to save time and
            # disk space.
            should_log = (
                (batch_idx % self.opt.log_frequency == 0
                 and self.step < 2000)
                or self.step % 2000 == 0)

            if should_log:
                if seg_key in inputs:
                    metrics = self.compute_segmentation_losses(inputs, outputs)
                    miou, macc = metrics["meaniou"], metrics["meanacc"]
                else:
                    metrics = {}
                    miou, macc = 0, 0
                self.log_time(batch_idx, duration,
                              losses["loss"].cpu().data, miou, macc)
                self.log("train", losses, metrics)
                self.val()
            self.step += 1

        self.model_lr_scheduler.step()

    def run_eval(self):
        """Evaluate on the full validation set through ``self.evaluator``.

        The model is switched to evaluation mode and is not switched back
        here; ``run_epoch`` re-enables training mode when it starts.
        """
        print("Validating on full validation set")
        self.set_eval()

        # The current epoch index is handed to the evaluator.
        self.evaluator.calculate_metrics(self.epoch)

    def val(self):
        """Validate the model on a single minibatch.

        Takes the next batch from ``self.val_iter``; when that iterator is
        exhausted a fresh one is created from ``self.val_loader``. The batch
        is processed without gradient tracking, segmentation metrics are
        computed when the batch carries a ``('segmentation', 0, 0)`` entry,
        and the result is logged under the ``"validation"`` writer. The model
        is put in evaluation mode for the duration and back in training mode
        afterwards.
        """
        self.set_eval()
        # Use the built-in next(): Python 3 iterators expose __next__ and
        # have no ``.next()`` method.
        try:
            inputs_val = next(self.val_iter)
        except StopIteration:
            self.val_iter = iter(self.val_loader)
            inputs_val = next(self.val_iter)

        with torch.no_grad():
            outputs_val, losses_val = self.process_batch(inputs_val)

            if ('segmentation', 0, 0) in inputs_val:
                metrics_val = self.compute_segmentation_losses(
                    inputs_val, outputs_val)
            else:
                metrics_val = {}

            self.log("validation", losses_val, metrics_val)

        self.set_train()

    def process_batch(self, inputs):
        """Run one minibatch through the model and compute its losses.

        Every value in ``inputs`` is moved to ``self.device`` (the dict is
        updated in place) before the model is called on it.

        :param inputs: dict of batch entries supporting ``.to(device)``
        :return: tuple ``(outputs, losses)``
        """
        for key in list(inputs):
            inputs[key] = inputs[key].to(self.device)

        outputs = self.model(inputs)
        return outputs, self.compute_losses(inputs, outputs)

    def compute_losses(self, inputs, outputs):
        """Compute the segmentation loss for a minibatch.

        Log-softmax is applied over dimension 1 of
        ``outputs['segmentation_logits']`` and the result is passed, together
        with index 0 along dimension 1 of ``inputs[('segmentation', 0, 0)]``
        cast to long, to ``self.crossentropy``.

        :return: dict with the single key ``"loss"``
        """
        log_probs = F.log_softmax(outputs['segmentation_logits'].float(),
                                  dim=1)
        labels = inputs[('segmentation', 0, 0)]
        targets = labels[:, 0, :, :].long()
        return {"loss": self.crossentropy(log_probs, targets)}

    def compute_segmentation_losses(self, inputs, outputs):
        """Score the current prediction against the ground-truth labels.

        Both arrays are copied to NumPy on the CPU, fed to
        ``self.metric_model`` and the resulting scores are returned. The
        metric accumulator is reset afterwards, so the scores cover this
        batch only.
        """
        ground_truth = inputs[('segmentation', 0, 0)].cpu()
        prediction = outputs['segmentation'].detach().cpu()

        scorer = self.metric_model
        scorer.update(np.array(ground_truth)[:, 0, :, :],
                      np.array(prediction))
        scores = scorer.get_scores()
        scorer.reset()
        return scores

    def log_time(self, batch_idx, duration, loss, miou, acc):
        """Print a one-line progress report to the terminal.

        The line shows the current epoch, the batch index, the throughput in
        examples per second (``opt.batch_size / duration``), the loss and the
        two segmentation metrics.
        """
        throughput = self.opt.batch_size / duration
        template = ("epoch {:>3} | batch {:>6} | examples/s: {:5.1f}"
                    " | loss: {:.5f}| meaniou: {:.5f}| meanacc: {:.5f}")
        message = template.format(self.epoch, batch_idx, throughput, loss,
                                  miou, acc)
        print(message)

    def log(self, mode, losses, metrics):
        """Write the scalars of one step to the writer selected by ``mode``.

        All entries of ``losses`` are written first, then the entries of
        ``metrics`` except those keyed ``'iou'``, ``'acc'`` and ``'prec'``.
        Every scalar is tagged with its key and recorded at ``self.step``.
        """
        writer = self.writers[mode]
        skipped_metrics = {'iou', 'acc', 'prec'}

        for name, value in losses.items():
            writer.add_scalar("{}".format(name), value, self.step)

        for name, value in metrics.items():
            if name not in skipped_metrics:
                writer.add_scalar("{}".format(name), value, self.step)

    def save_opts(self, n_train, n_eval):
        """Dump the run options to ``<log_path>/models/opt.json``.

        The attributes of ``self.opt`` are written together with the two
        extra entries ``n_train`` and ``n_eval``; ``self.opt`` itself is not
        modified. The ``models`` directory is created when missing.

        :param n_train: value stored under ``'n_train'``
        :param n_eval: value stored under ``'n_eval'``
        """
        models_dir = os.path.join(self.log_path, "models")
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        options = dict(self.opt.__dict__, n_train=n_train, n_eval=n_eval)

        opt_file = os.path.join(models_dir, 'opt.json')
        with open(opt_file, 'w') as handle:
            json.dump(options, handle, indent=2)

    def save_model(self):
        """Write the model and optimizer state for the current epoch to disk.

        Both state dicts go to ``<log_path>/models/weights_<epoch>/`` as
        ``model.pth`` and ``optim.pth``; the folder is created when missing.
        """
        save_folder = os.path.join(self.log_path, "models",
                                   "weights_{}".format(self.epoch))
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        # Model first, optimizer second.
        checkpoints = (("model", self.model), ("optim", self.model_optimizer))
        for stem, component in checkpoints:
            target = os.path.join(save_folder, "{}.pth".format(stem))
            torch.save(component.state_dict(), target)

    def load_model(self, adam=True):
        """Restore model weights (and optionally optimizer state) from disk.

        The checkpoint folder is
        ``<parent of log_path>/<opt.load_model_name>/models/weights_<opt.weights_epoch>``.
        Only checkpoint entries whose key exists in the current model's state
        dict are taken over; all other model entries keep their current
        values.

        :param adam: whether to load the Adam state too
        """
        runs_root = os.path.dirname(self.log_path)
        checkpoint_path = os.path.join(
            runs_root, self.opt.load_model_name, 'models',
            'weights_{}'.format(self.opt.weights_epoch))
        assert os.path.isdir(checkpoint_path), \
            "Cannot find folder {}".format(checkpoint_path)
        print("loading model from folder {}".format(checkpoint_path))

        weights_file = os.path.join(checkpoint_path, "{}.pth".format('model'))
        merged_state = self.model.state_dict()
        for key, value in torch.load(weights_file).items():
            if key in merged_state:
                merged_state[key] = value
        self.model.load_state_dict(merged_state)

        if not adam:
            return

        # Optimizer state is optional: fall back to a fresh optimizer when the
        # file is missing.
        optimizer_file = os.path.join(checkpoint_path,
                                      "{}.pth".format("optim"))
        if not os.path.isfile(optimizer_file):
            print("Cannot find Adam weights so Adam is randomly initialized")
            return

        print("Loading Adam weights")
        self.model_optimizer.load_state_dict(torch.load(optimizer_file))

    def _print_options(self):
        """Print training options to stdout so that they appear in the SLURM log
        """
        # Convert namespace to dictionary
        opts = vars(self.opt)

        # Get max key length for left justifying
        max_len = max([len(key) for key in opts.keys()])

        # Print options to stdout
        print(
            "+++++++++++++++++++++++++++ OPTIONS +++++++++++++++++++++++++++")
        for item in sorted(opts.items()):
            print(item[0].ljust(max_len), item[1])
        print(
            "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
Beispiel #9
0
def train(train_loader, val_loader, circ_S):
    """Train an ERFNet model with a ``ReconsLoss`` criterion and return it.

    Builds the model, the criterion, an Adam optimizer and a step
    learning-rate schedule from the global ``args``, optionally resumes from
    a checkpoint, then alternates one training and one validation epoch. A
    checkpoint is saved whenever the validation SNR exceeds the best value
    seen so far. When ``args.visdom`` is set, loss and SNR curves are
    streamed to a Visdom server.

    :param train_loader: data loader handed to ``Train``
    :param val_loader: data loader handed to ``Test``
    :param circ_S: argument passed to ``ReconsLoss``
    :return: the trained model
    """
    print("\nTraining...\n")

    model = ERFNet(1).to(device).double()
    #criterion = nn.MSELoss()
    criterion = ReconsLoss(circ_S)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)

    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_loss, best_snr = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name)
        # Each value gets its own placeholder; previously the SNR was printed
        # in the loss slot as well and the loss was never shown.
        print("Resuming from model: Start epoch = {0} "
              "| Best mean loss = {1:.4f} Best mean snr = {2:.4f}".format(
                  start_epoch, best_loss, best_snr))
    else:
        start_epoch = 0
        best_loss = 0
        best_snr = 0

    if args.visdom:
        vis = visdom.Visdom()

        # Each window holds two curves (train, test), seeded with one point.
        loss_win = vis.line(X=np.column_stack(
            (np.array(start_epoch), np.array(start_epoch))),
                            Y=np.column_stack(
                                (np.array(best_loss), np.array(best_loss))),
                            opts=dict(legend=['train', 'test'],
                                      xlabel='epoch',
                                      ylabel='loss',
                                      title='Loss'))
        snr_win = vis.line(X=np.column_stack(
            (np.array(start_epoch), np.array(start_epoch))),
                           Y=np.column_stack((np.array(0.), np.array(0.))),
                           opts=dict(legend=['train', 'test'],
                                     xlabel='epoch',
                                     ylabel='snr',
                                     title='SNR'))

    # Start Training
    print()
    # Named trainer/validator so the enclosing function name is not shadowed.
    trainer = Train(model, train_loader, optimizer, criterion, device)
    validator = Test(model, val_loader, criterion, device)
    for epoch in range(start_epoch, args.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        epoch_loss, epoch_snr = trainer.run_epoch(lr_updater, args.print_step)
        lr_updater.step()
        # Read the rate from the optimizer the scheduler drives: get_lr() is
        # only meant to be called from inside step() and can report a wrong
        # value at a decay boundary.
        current_lr = lr_updater.optimizer.param_groups[0]['lr']
        print(
            ">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} Avg. snr: {2:.4f} Lr: {3:f}"
            .format(epoch, epoch_loss, epoch_snr, current_lr))

        # The modulus of 1 makes this run after every epoch; the visdom
        # update below relies on ``loss`` and ``snr`` being set here.
        if (epoch + 1) % 1 == 0 or epoch + 1 == args.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, snr = validator.run_epoch(args.print_step)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} Avg. snr: {2:.4f}".
                  format(epoch, loss, snr))

            # Save the model if it's the best thus far
            # NOTE(review): best_loss is never updated here, so the value
            # stored in the checkpoint is the initial/resumed one - confirm
            # whether it should track the loss of the best-SNR epoch.
            if snr > best_snr:
                print("\nBest model thus far. Saving...\n")
                best_snr = snr
                utils.save_checkpoint(model, optimizer, epoch + 1, best_loss,
                                      best_snr, args)
        if args.visdom:
            vis.line(X=np.column_stack((np.array(epoch), np.array(epoch))),
                     Y=np.column_stack((np.array(epoch_loss), np.array(loss))),
                     win=loss_win,
                     update='append')

            vis.line(X=np.column_stack((np.array(epoch), np.array(epoch))),
                     Y=np.column_stack((np.array(epoch_snr), np.array(snr))),
                     win=snr_win,
                     update='append')

    return model
Beispiel #10
0
    # Fail fast if the dataset directory doesn't exist
    assert os.path.isdir(
        args.dataset_dir), "The directory \"{0}\" doesn't exist.".format(
            args.dataset_dir)

    # Fail fast if the saving directory doesn't exist
    assert os.path.isdir(
        args.save_dir), "The directory \"{0}\" doesn't exist.".format(
            args.save_dir)

    train_loader, val_loader, test_loader = load_dataset(dataset)
    # Only the first value returned by hadamard_s is used below.
    circ_S, _ = hadamard_s(args.matrix_size)

    # 'full' mode runs training followed by testing.
    if args.mode.lower() in {'train', 'full'}:
        model = train(train_loader, val_loader, circ_S)

    if args.mode.lower() in {'test', 'full'}:
        if args.mode.lower() == 'test':
            # Initialize a new ERFNet model
            model = ERFNet(1).to(device).double()

        # Initialize an optimizer just so we can retrieve the model from the
        # checkpoint
        optimizer = optim.Adam(model.parameters())

        # Load the previously saved model state to the ERFNet model
        # (in 'full' mode this replaces the state of the model that train()
        # returned with the saved checkpoint)
        model = utils.load_checkpoint(model, optimizer, args.save_dir,
                                      args.name)[0]

        test(model, test_loader, circ_S)
Beispiel #11
0
# Demo script: load one sample image, show it, and prepare the segmentation
# and classification networks with their pretrained weights.
# im = Image.open('images/test.jpg')
im = Image.open('/aimldl-dat/samples/lanenet/4.jpg')
print(im)
# ipynb visualization
# get_ipython().run_line_magic('matplotlib', 'inline')
imgplot = imshow(np.asarray(im))
show()

# Resize to the network input size.
im = im.resize((WIDTH, HEIGHT))

# Convert to a tensor and add a leading batch dimension.
im_tensor = ToTensor()(im)
im_tensor = im_tensor.unsqueeze(0)

# We also need to load the weights of the CNNs. We simply load them using PyTorch methods.
# Creating CNNs and loading pretrained models
segmentation_network = ERFNet(NUM_CLASSES_SEGMENTATION)
classification_network = LCNet(NUM_CLASSES_CLASSIFICATION, DESCRIPTOR_SIZE,
                               DESCRIPTOR_SIZE)

segmentation_network.load_state_dict(
    torch.load('pretrained/erfnet_tusimple.pth', map_location=map_location))
# The classification checkpoint name encodes descriptor size and class count.
model_path = 'pretrained/classification_{}_{}class.pth'.format(
    DESCRIPTOR_SIZE, NUM_CLASSES_CLASSIFICATION)
classification_network.load_state_dict(
    torch.load(model_path, map_location=map_location))

# Inference only: put both networks in evaluation mode.
segmentation_network = segmentation_network.eval()
classification_network = classification_network.eval()

# NOTE(review): only the segmentation network is moved to the GPU in the
# visible part of this script; the classification network and im_tensor are
# not - confirm whether the script continues beyond this point.
if torch.cuda.is_available():
    segmentation_network = segmentation_network.cuda()
Beispiel #12
0
def video():
    """Segment every 1000th frame of the test video with ERFNet and show it.

    The checkpoint named by ``args.save_dir`` / ``args.name`` is loaded once,
    the model is put in evaluation mode on the GPU, and the video
    ``test_content/massachusetts.mp4`` is read frame by frame. Every 1000th
    frame is optionally undistorted (``args.rmdistort``), resized to
    ``args.width`` x ``args.height``, segmented, colorized and displayed
    next to the input. Timing is printed as frames per second.
    """
    print('testing from video')
    cameraWidth = 1920
    cameraHeight = 1080
    cameraMatrix = np.matrix([[1.3878727764994030e+03, 0, cameraWidth / 2],
                              [0, 1.7987055172413220e+03, cameraHeight / 2],
                              [0, 0, 1]])

    distCoeffs = np.matrix([
        -5.8881725390917083e-01, 5.8472404395779809e-01,
        -2.8299599929891900e-01, 0
    ])

    class_encoding = OrderedDict([
        ('unlabeled', (0, 0, 0)), ('road', (128, 64, 128)),
        ('sidewalk', (244, 35, 232)), ('building', (70, 70, 70)),
        ('wall', (102, 102, 156)), ('fence', (190, 153, 153)),
        ('pole', (153, 153, 153)), ('traffic_light', (250, 170, 30)),
        ('traffic_sign', (220, 220, 0)),
        ('vegetation', (107, 142, 35)), ('terrain', (152, 251, 152)),
        ('sky', (70, 130, 180)), ('person', (220, 20, 60)),
        ('rider', (255, 0, 0)), ('car', (0, 0, 142)),
        ('truck', (0, 0, 70)), ('bus', (0, 60, 100)),
        ('train', (0, 80, 100)), ('motorcycle', (0, 0, 230)),
        ('bicycle', (119, 11, 32))
    ])
    num_classes = len(class_encoding)

    # Build the model and load its weights once instead of once per
    # processed frame.
    model_path = os.path.join(args.save_dir, args.name)
    checkpoint = torch.load(model_path)
    model = ERFNet(num_classes)
    model = model.cuda()
    model.load_state_dict(checkpoint['state_dict'])
    # Inference only: evaluation mode makes layers such as dropout and batch
    # normalization behave deterministically.
    model.eval()

    label_to_rgb = transforms.Compose([
        utils.LongTensorToRGBPIL(class_encoding),
        transforms.ToTensor()
    ])

    # The undistortion maps depend only on the fixed camera parameters, so
    # they are computed once up front.
    undistort_maps = None
    if args.rmdistort:
        P = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
            cameraMatrix, distCoeffs, (cameraWidth, cameraHeight),
            None)
        undistort_maps = cv2.fisheye.initUndistortRectifyMap(
            cameraMatrix, distCoeffs, np.eye(3), P,
            (cameraWidth, cameraHeight), cv2.CV_16SC2)

    vidcap = cv2.VideoCapture('test_content/massachusetts.mp4')
    i = 0
    try:
        while True:
            success, img = vidcap.read()
            # read() signals the end of the stream (or a read failure) with
            # success == False and no image; stop instead of processing it.
            if not success:
                break
            if i % 1000 == 0:
                print("frame: ", i)
                if undistort_maps is not None:
                    map1, map2 = undistort_maps
                    img = cv2.remap(img, map1, map2, cv2.INTER_LINEAR)
                # OpenCV delivers BGR; PIL and the model pipeline expect RGB.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(img)
                img = img.resize((args.width, args.height), Image.BILINEAR)
                start = time.time()
                # ToTensor yields a C x H x W tensor; add the batch dimension.
                images = transforms.ToTensor()(img).unsqueeze(0)
                with torch.no_grad():
                    images = images.cuda()
                    predictions = model(images)
                    end = time.time()
                    print('model speed:', int(1 / (end - start)), "FPS")
                    # Index of the highest-scoring class per pixel.
                    _, predictions = torch.max(predictions.data, 1)
                    color_predictions = utils.batch_transform(
                        predictions.cpu(), label_to_rgb)
                    end = time.time()
                    print('model+transform:', int(1 / (end - start)), "FPS")
                    utils.imshow_batch(images.data.cpu(), color_predictions)
            i += 1
    finally:
        # Always free the capture handle, even if a frame fails to process.
        vidcap.release()
Beispiel #13
0
def train(train_loader, val_loader, class_weights, class_encoding):
    """Train an ERFNet segmentation model and return it with its history.

    Builds the model, a class-weighted cross-entropy criterion, an Adam
    optimizer and a step learning-rate schedule from the global ``args``,
    optionally resumes from a checkpoint, and trains for ``args.epochs``
    epochs. Validation runs every 10th epoch and on the last epoch; the model
    is checkpointed whenever the validation mean IoU improves.

    :param train_loader: data loader handed to ``Train``
    :param val_loader: data loader handed to ``Test``
    :param class_weights: per-class weights for ``nn.CrossEntropyLoss``
    :param class_encoding: ordered mapping of class name to color; its length
        gives the number of classes and its keys label the per-class IoU
    :return: tuple ``(model, train_loss, train_miou, val_loss, val_miou)``
    """
    print("Training...")
    num_classes = len(class_encoding)
    model = ERFNet(num_classes)
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Move the model to the GPU before the optimizer is built, as the
    # torch.optim documentation recommends.
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None

    metric = IoU(num_classes, ignore_index=ignore_index)

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_miou, val_miou, train_miou, val_loss, train_loss = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name, True)
        print(
            "Resuming from model: Start epoch = {0} | Best mean IoU = {1:.4f}".
            format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0
        val_miou = []
        train_miou = []
        val_loss = []
        train_loss = []

    # Start Training
    # Named trainer/validator so the enclosing function name is not shadowed.
    trainer = Train(model, train_loader, optimizer, criterion, metric,
                    use_cuda)
    validator = Test(model, val_loader, criterion, metric, use_cuda)

    for epoch in range(start_epoch, args.epochs):
        print(">> [Epoch: {0:d}] Training".format(epoch))
        epoch_loss, (iou, miou) = trainer.run_epoch(args.print_step)
        # Step the schedule after the epoch's optimizer updates; stepping
        # first shifts the whole decay schedule one epoch early.
        lr_updater.step()
        print(
            ">> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".format(
                epoch, epoch_loss, miou))
        train_loss.append(epoch_loss)
        train_miou.append(miou)

        # Perform a validation test
        if (epoch + 1) % 10 == 0 or epoch + 1 == args.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))
            loss, (iou, miou) = validator.run_epoch(args.print_step)
            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))
            val_loss.append(loss)
            val_miou.append(miou)
            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == args.epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))
            # Save the model if it's the best thus far
            if miou > best_miou:
                print("Best model thus far. Saving...")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      val_miou, train_miou, val_loss,
                                      train_loss, args)

    return model, train_loss, train_miou, val_loss, val_miou
Beispiel #14
0
        # load_dataset returns the three data loaders, the class weights and
        # the class encoding.
        loaders, w_class, class_encoding = load_dataset(dataset)
        train_loader, val_loader, test_loader = loaders

        if args.mode.lower() in {'train'}:
            model, tl, tmiou, vl, vmiou = train(train_loader, val_loader,
                                                w_class, class_encoding)
            # Plot the four history curves returned by train() into one
            # figure and save it.
            # NOTE(review): the validation lists may be shorter than the
            # training lists if validation does not run every epoch, so the
            # curves would not share an epoch axis - confirm against train().
            plt.plot(tl, label="train loss")
            plt.plot(tmiou, label="train miou")
            plt.plot(vl, label="val loss")
            plt.plot(vmiou, label="val miou")
            plt.legend()
            plt.xlabel("Epoch")
            plt.ylabel("loss/accuracy")
            plt.grid(True)
            plt.xticks()
            # NOTE(review): presumably requires ./plots to exist already -
            # verify.
            plt.savefig('./plots/train.png')
        elif args.mode.lower() == 'test':
            num_classes = len(class_encoding)
            #model = ENet(num_classes)
            model = ERFNet(num_classes)
            if use_cuda:
                model = model.cuda()
            # The optimizer is created only because load_checkpoint takes
            # one; just the model (element 0 of the result) is kept.
            optimizer = optim.Adam(model.parameters())
            model = utils.load_checkpoint(model, optimizer, args.save_dir,
                                          args.name)[0]
            test(model, test_loader, w_class, class_encoding)
        else:
            # Any other mode string is rejected.
            raise RuntimeError(
                "\"{0}\" is not a valid choice for execution mode.".format(
                    args.mode))