Example #1
0
def cityscapes_train(resize_height, resize_width, crop_height, crop_width,
                     batch_size, num_workers):
    """Build the training DataLoader for Cityscapes semantic segmentation.

    Samples come from the 'train' split of the 'cityscapes' dataset with the
    'color' and 'segmentation' keys loaded. Every sample is flipped, resized,
    rescaled, cropped and colour-jittered, then converted to a normalised
    tensor and tagged with domain / purpose / class-count metadata.

    Args:
        resize_height: Height passed to the Resize transform.
        resize_width: Width passed to the Resize transform.
        crop_height: Height passed to the RandomCrop transform.
        crop_width: Width passed to the RandomCrop transform.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.

    Returns:
        A shuffling DataLoader with pinned memory that drops the last
        incomplete batch.
    """
    labels = labels_cityscape_seg.getlabels()
    num_classes = len(labels_cityscape_seg.gettrainid2label())

    resize_to = (resize_height, resize_width)
    crop_to = (crop_height, crop_width)

    # Geometric augmentation and label conversion.
    geometry_stage = [
        tf.RandomHorizontalFlip(),
        tf.CreateScaledImage(),
        tf.Resize(resize_to),
        tf.RandomRescale(1.5),
        tf.RandomCrop(crop_to),
        tf.ConvertSegmentation(),
    ]

    # Colour augmentation on a newly created element.
    jitter_settings = {
        'brightness': 0.2,
        'contrast': 0.2,
        'saturation': 0.2,
        'hue': 0.1,
        'gamma': 0.0,
    }
    color_stage = [
        tf.CreateColoraug(new_element=True),
        tf.ColorJitter(**jitter_settings),
    ]

    # Drop the originals and convert what is left to normalised tensors.
    tensor_stage = [
        tf.RemoveOriginals(),
        tf.ToTensor(),
        tf.NormalizeZeroMean(),
    ]

    # Constant metadata attached to every sample.
    metadata_stage = [
        tf.AddKeyValue('domain', 'cityscapes_train_seg'),
        tf.AddKeyValue('purposes', ('segmentation', 'domain')),
        tf.AddKeyValue('num_classes', num_classes),
    ]

    pipeline = geometry_stage + color_stage + tensor_stage + metadata_stage

    dataset = StandardDataset(
        dataset='cityscapes',
        trainvaltest_split='train',
        video_mode='mono',
        stereo_mode='mono',
        labels_mode='fromid',
        disable_const_items=True,
        labels=labels,
        keys_to_load=('color', 'segmentation'),
        data_transforms=pipeline,
        video_frames=(0, ),
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    summary = f"  - Can use {len(dataset)} images from the cityscapes train set for segmentation training"
    print(summary, flush=True)

    return loader
    def __init__(self, options):
        """Set up everything needed for an ERFNet segmentation training run.

        In order, this builds: the label set and city filter, the training
        and validation datasets with their loaders, the TensorBoard writers
        and log directory, the ERFNet model, the Adam optimizer with its
        learning-rate schedule, the cross-entropy loss, the running
        segmentation metric and, if requested, an evaluator.

        Args:
            options: Option namespace. Attributes read here: ``no_cuda``,
                ``train_set``, ``city``, ``height``, ``width``,
                ``crop_height``, ``crop_width``, ``scales``,
                ``video_frames``, ``batch_size``, ``num_workers``,
                ``dataset``, ``dataset_split``, ``model_name``,
                ``learning_rate``, ``weight_decay``, ``num_epochs`` and
                ``validate``.

        Raises:
            ValueError: If ``options.train_set`` is neither ``1`` nor ``123``.
        """

        print(" -> Executing script", os.path.basename(__file__))

        self.opt = options
        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LABELS AND CITIES
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Explicit raise rather than ``assert`` so the check is not stripped
        # when Python runs with -O.
        if self.opt.train_set not in {123, 1}:
            raise ValueError("Invalid train_set!")
        keys_to_load = ['color', 'segmentation']

        # Labels
        if self.opt.train_set == 1:
            labels = labels_cityscape_seg_train1.getlabels()
        else:
            labels = labels_cityscape_seg_train3_eval.getlabels()

        # Train IDs: every distinct trainId except 255 (presumably the
        # ignore/void id -- confirm against the label definitions).
        # discard() keeps this working for label sets that do not contain 255.
        self.train_ids = {label.trainId for label in labels}
        self.train_ids.discard(255)

        self.num_classes = len(self.train_ids)

        # Apply city filter: city set 0 unless the ``city`` option is set, in
        # which case the set matching ``train_set`` is used.
        folders_to_train = CitySet.get_city_set(0)
        if self.opt.city:
            folders_to_train = CitySet.get_city_set(self.opt.train_set)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           DATASET DEFINITIONS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Data augmentation
        train_data_transforms = [
            mytransforms.RandomHorizontalFlip(),
            mytransforms.CreateScaledImage(),
            mytransforms.Resize((self.opt.height, self.opt.width),
                                image_types=keys_to_load),
            mytransforms.RandomRescale(1.5),
            mytransforms.RandomCrop(
                (self.opt.crop_height, self.opt.crop_width)),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True,
                                        scales=self.opt.scales),
            mytransforms.ColorJitter(brightness=0.2,
                                     contrast=0.2,
                                     saturation=0.2,
                                     hue=0.1,
                                     gamma=0.0),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        train_dataset = CityscapesDataset(
            dataset="cityscapes",
            trainvaltest_split='train',
            video_mode='mono',
            stereo_mode='mono',
            scales=self.opt.scales,
            labels_mode='fromid',
            labels=labels,
            keys_to_load=keys_to_load,
            data_transforms=train_data_transforms,
            video_frames=self.opt.video_frames,
            folders_to_load=folders_to_train,
        )

        self.train_loader = DataLoader(dataset=train_dataset,
                                       batch_size=self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       pin_memory=True,
                                       drop_last=True)

        # Validation uses the same pipeline minus the random augmentations
        # (no flip, rescale, crop or colour jitter).
        val_data_transforms = [
            mytransforms.CreateScaledImage(),
            mytransforms.Resize((self.opt.height, self.opt.width),
                                image_types=keys_to_load),
            mytransforms.ConvertSegmentation(),
            mytransforms.CreateColoraug(new_element=True,
                                        scales=self.opt.scales),
            mytransforms.RemoveOriginals(),
            mytransforms.ToTensor(),
            mytransforms.NormalizeZeroMean(),
        ]

        # NOTE(review): validation items are read from the "train" split,
        # restricted to city set -1 -- presumably held-out training cities;
        # confirm against CitySet.
        val_dataset = CityscapesDataset(
            dataset=self.opt.dataset,
            trainvaltest_split="train",
            video_mode='mono',
            stereo_mode='mono',
            scales=self.opt.scales,
            labels_mode='fromid',
            labels=labels,
            keys_to_load=keys_to_load,
            data_transforms=val_data_transforms,
            video_frames=self.opt.video_frames,
            folders_to_load=CitySet.get_city_set(-1))

        self.val_loader = DataLoader(dataset=val_dataset,
                                     batch_size=self.opt.batch_size,
                                     shuffle=False,
                                     num_workers=self.opt.num_workers,
                                     pin_memory=True,
                                     drop_last=True)

        self.val_iter = iter(self.val_loader)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LOGGING OPTIONS
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        print(
            "++++++++++++++++++++++ INIT TRAINING ++++++++++++++++++++++++++")
        print("Using dataset:\n  ", self.opt.dataset, "with split",
              self.opt.dataset_split)
        print(
            "There are {:d} training items and {:d} validation items\n".format(
                len(train_dataset), len(val_dataset)))

        path_getter = GetPath()
        log_path = path_getter.get_checkpoint_path()
        self.log_path = os.path.join(log_path, 'erfnet', self.opt.model_name)

        # One writer per mode, each in its own sub-directory of the log path.
        self.writers = {}
        for mode in ["train", "validation"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.log_path, mode))

        # Copy this file to log dir
        shutil.copy2(__file__, self.log_path)

        print("Training model named:\n  ", self.opt.model_name)
        print("Models and tensorboard events files are saved to:\n  ",
              self.log_path)
        print("Training is using:\n  ", self.device)
        print("Training takes place on train set:\n  ", self.opt.train_set)
        print(
            "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           MODEL DEFINITION
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Instantiate model
        self.model = ERFNet(self.num_classes, self.opt)
        self.model.to(self.device)
        # NOTE(review): if ERFNet is a torch.nn.Module, parameters() returns
        # a one-shot iterator that the optimizer below consumes, so this
        # attribute cannot be iterated again afterwards -- confirm no later
        # code relies on it.
        self.parameters_to_train = self.model.parameters()

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           OPTIMIZER SET-UP
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        self.model_optimizer = optim.Adam(params=self.parameters_to_train,
                                          lr=self.opt.learning_rate,
                                          weight_decay=self.opt.weight_decay)

        def poly_lr_factor(epoch):
            """Return the multiplicative LR factor for the given epoch."""
            # NOTE(review): LambdaLR starts counting epochs at 0, so with the
            # ``epoch - 1`` offset the first factor is slightly above 1 --
            # confirm this is intended.
            return pow((1 - ((epoch - 1) / self.opt.num_epochs)), 0.9)

        self.model_lr_scheduler = optim.lr_scheduler.LambdaLR(
            self.model_optimizer, lr_lambda=poly_lr_factor)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           LOSSES
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        self.crossentropy = CrossEntropyLoss(ignore_background=True,
                                             device=self.device)
        self.crossentropy.to(self.device)

        self.metric_model = SegmentationRunningScore(self.num_classes)

        # Save all options to disk and print them to stdout
        self.save_opts(len(train_dataset), len(val_dataset))
        self._print_options()

        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        #                           EVALUATOR DEFINITION
        # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        if self.opt.validate:
            self.evaluator = Evaluator(self.opt, self.model)
Example #3
0
    # Example configuration: dataset name, split and the item types to load.
    dataset = 'cityscapes'
    trainvaltest_split = 'train'
    keys_to_load = ['color', 'depth', 'segmentation',
                    'camera_intrinsics']  # Optional; the default is just 'color'

    # The following parameters and the data_transforms list are optional.
    # The default is just the ToTensor() transform.
    width = 640
    height = 192
    scales = [0, 1, 2, 3]
    # Transform pipeline in list order. The trailing tuples such as
    # (color, 0, 0) are the original author's annotations -- presumably the
    # sample keys touched by that transform; confirm against mytransforms.
    data_transforms = [  #mytransforms.RandomExchangeStereo(),  # (color, 0, -1)
        mytransforms.RandomHorizontalFlip(),
        mytransforms.RandomVerticalFlip(),
        mytransforms.CreateScaledImage(),  # (color, 0, 0)
        # NOTE(review): zero-valued arguments -- presumably these two leave
        # the sample unchanged; confirm.
        mytransforms.RandomRotate(0.0),
        mytransforms.RandomTranslate(0),
        mytransforms.RandomRescale(scale=1.1, fraction=0.5),
        # Crop with size (320, 1088), then resize to (height, width).
        mytransforms.RandomCrop((320, 1088)),
        mytransforms.Resize((height, width)),
        mytransforms.MultiResize(scales),
        mytransforms.CreateColoraug(new_element=True,
                                    scales=scales),  # (color_aug, 0, 0)
        mytransforms.ColorJitter(brightness=0.2,
                                 contrast=0.2,
                                 saturation=0.2,
                                 hue=0.1,
                                 gamma=0.0),
        mytransforms.GaussianBlurr(fraction=0.5),
        mytransforms.RemoveOriginals(),
        mytransforms.ToTensor(),
        mytransforms.NormalizeZeroMean(),
    ]