Exemplo n.º 1
0
def load_data(datadir):
    """Create the Pascal VOC segmentation datasets for training and validation.

    Both splits are built with ``VOCSegmentation(datadir, ..., download=True)``.
    The training pipeline random-resizes between 0.5x and 2.0x of a 320 px
    base size, takes a 256 px random crop, flips horizontally with p=0.5,
    applies colour jitter, converts to tensors and normalises with ImageNet
    statistics. The validation pipeline uses ``RandomResize(320, 320)`` (equal
    bounds), ``ToTensor`` and the same normalisation.

    Args:
        datadir: Root directory passed to ``VOCSegmentation``.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple.
    """
    # ImageNet channel statistics, wrapped in SampleTransform below.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    base_size, crop_size = 320, 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    def build(banner, image_set, make_steps):
        # Print a banner, then time the VOCSegmentation construction
        # (including its download=True handling) and report the duration.
        print(banner)
        started = time.time()
        split = VOCSegmentation(datadir,
                                image_set=image_set,
                                download=True,
                                transforms=Compose(make_steps()))
        print("Took", time.time() - started)
        return split

    train_set = build(
        "Loading training data",
        'train',
        lambda: [
            RandomResize(min_size, max_size),
            RandomCrop(crop_size),
            RandomHorizontalFlip(0.5),
            SampleTransform(transforms.ColorJitter(brightness=0.3,
                                                   contrast=0.3,
                                                   saturation=0.1,
                                                   hue=0.02)),
            ToTensor(),
            SampleTransform(normalize),
        ],
    )

    val_set = build(
        "Loading validation data",
        'val',
        lambda: [
            RandomResize(base_size, base_size),
            ToTensor(),
            SampleTransform(normalize),
        ],
    )

    return train_set, val_set
Exemplo n.º 2
0
 def get_training_loader(img_root, label_root, file_list, batch_size,
                         img_height, img_width, num_class):
     """Return a shuffling training DataLoader over ``VOCTestDataset``.

     Each sample is randomly flipped, resized to
     ``(img_height + 5, img_width + 5)``, randomly cropped back to
     ``(img_height, img_width)``, converted to a tensor and normalised with
     ``imagenet_stats``.

     ``num_class`` is currently unused: the one-hot label step is disabled.
     """
     enlarged = (img_height + 5, img_width + 5)
     target = (img_height, img_width)
     pipeline = transforms.Compose([
         RandomHorizontalFlip(),
         Resize(enlarged),
         RandomCrop(target),
         ToTensor(),
         Normalize(imagenet_stats['mean'], imagenet_stats['std']),
         # GenOneHotLabel(num_class),
     ])
     dataset = VOCTestDataset(img_root, label_root, file_list,
                              transform=pipeline)
     # Loading happens in the main process, without pinned memory.
     return DataLoader(dataset, batch_size,
                       shuffle=True, num_workers=0, pin_memory=False)
Exemplo n.º 3
0
    def __init__(
            self,
            data,
            roi_size=64,
            zoom_range=(0.8, 1.25),
            samples_per_epoch=100000,
    ):
        """Store the data and build the random-ROI augmentation pipeline.

        Args:
            data: Source data that samples are drawn from.
            roi_size: Side length of the region of interest.
            zoom_range: Forwarded to ``AffineTransform``.
            samples_per_epoch: Stored as ``self.samples_per_epoch``.
        """
        self.data = data
        self.roi_size = roi_size

        # Margin per side equal to half the gap between the ROI diagonal
        # (roi * sqrt(2)) and its side. The original name for it was
        # "rotation_pad_size", presumably so that a rotation inside
        # AffineTransform leaves no empty corners — confirm.
        margin = math.ceil(roi_size * (math.sqrt(2) - 1) / 2)
        self.transforms = Compose([
            RandomCrop(roi_size + 2 * margin),
            AffineTransform(zoom_range),
            RemovePadding(margin),
            RandomVerticalFlip(),
            RandomHorizontalFlip(),
        ])

        self.samples_per_epoch = samples_per_epoch
Exemplo n.º 4
0
def init(batch_size, state, split, input_sizes, sets_id, std, mean, keep_scale, reverse_channels, data_set,
         valtiny, no_aug):
    """Build the data loader(s) for semi-/fully-supervised training or testing.

    Args:
        batch_size: Batch size of the validation, reference and unlabeled
            loaders and of the fully-supervised labeled loader. In
            semi-supervised mode the labeled and pseudo-labeled loaders each
            use ``int(batch_size / 2)``.
        state: 1 = semi-supervised training, 2 = fully-supervised training,
            3 = testing only.
        split: Prefix of the image-set names
            (``'<split>_labeled_<sets_id>'`` / ``'<split>_unlabeled_<sets_id>'``).
        input_sizes: ``[0]`` crop/resize size and lower bound of the random
            resize; ``[1]`` upper bound of the random resize; ``[2]`` size used
            by the validation transform (ZeroPad for VOC, Resize for City).
        sets_id: Suffix of the image-set names.
        std, mean: Normalisation statistics passed to ``Normalize``.
        keep_scale, reverse_channels: Forwarded to ``ToTensor``.
        data_set: ``'voc'`` or ``'city'``.
        valtiny: Validate on ``'valtiny'`` instead of ``'val'``.
        no_aug: Use a plain Resize (no random augmentation) for the
            pseudo-labeled training set.

    Returns:
        state 3: ``val_loader``.
        state 2: ``(labeled_loader, val_loader)``.
        state 1: ``(labeled_loader, pseudo_labeled_loader, unlabeled_loader,
        val_loader, reference_loader)``.

    Raises:
        ValueError: if ``data_set`` or ``state`` is not supported.
    """
    # Transformations (compatible with unlabeled data/pseudo labeled data)
    # ! Can't use torchvision.Transforms.Compose
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             ZeroPad(size=input_sizes[2]),
             Normalize(mean=mean, std=std)])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        # NOTE(review): unlike transform_train / transform_pseudo, the
        # pseudo-labeled training transforms apply no LabelMap — presumably
        # because pseudo labels are already in train-ID space; confirm.
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
    else:
        # Previously this set base = '' and then crashed with a NameError on
        # the undefined transforms/workers; fail early with a clear message.
        raise ValueError(f"Unsupported data_set: {data_set!r} (expected 'voc' or 'city')")

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(root=base, image_set='valtiny' if valtiny else 'val',
                                           transforms=transform_test, label_state=0, data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, num_workers=workers, shuffle=False)

    # Testing
    if state == 3:
        return val_loader

    labeled_split = f"{split}_labeled_{sets_id}"
    unlabeled_split = f"{split}_unlabeled_{sets_id}"

    # Fully-supervised training
    if state == 2:
        labeled_set = StandardSegmentationDataset(root=base, image_set=labeled_split,
                                                  transforms=transform_train, label_state=0, data_set=data_set)
        labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set, batch_size=batch_size,
                                                     num_workers=workers, shuffle=True)
        return labeled_loader, val_loader

    # Semi-supervised training
    if state == 1:
        pseudo_labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                         image_set=unlabeled_split,
                                                         transforms=transform_train_pseudo, label_state=1)
        reference_set = SegmentationLabelsDataset(root=base, image_set=unlabeled_split,
                                                  data_set=data_set)
        reference_loader = torch.utils.data.DataLoader(dataset=reference_set, batch_size=batch_size,
                                                       num_workers=workers, shuffle=False)
        unlabeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                    image_set=unlabeled_split,
                                                    transforms=transform_pseudo, label_state=2)
        labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                  image_set=labeled_split,
                                                  transforms=transform_train, label_state=0)

        unlabeled_loader = torch.utils.data.DataLoader(dataset=unlabeled_set, batch_size=batch_size,
                                                       num_workers=workers, shuffle=False)

        # Labeled and pseudo-labeled batches each take half of batch_size.
        pseudo_labeled_loader = torch.utils.data.DataLoader(dataset=pseudo_labeled_set,
                                                            batch_size=int(batch_size / 2),
                                                            num_workers=workers, shuffle=True)
        labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set,
                                                     batch_size=int(batch_size / 2),
                                                     num_workers=workers, shuffle=True)
        return labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, reference_loader

    # Support unsupervised learning here if that's what you want
    raise ValueError(f"Unsupported state: {state!r} (expected 1, 2 or 3)")
Exemplo n.º 5
0
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from dataset import FaceLandmarksDataset
from transforms import Rescale, RandomCrop, ToTensor

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Interactive plotting mode.
# NOTE(review): `plt` (matplotlib.pyplot) is not imported in the visible
# import block — confirm it is imported elsewhere.
plt.ion()

transformed_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]),
)

# Sanity check: print the tensor sizes of (at most) the first four samples.
for i in range(min(4, len(transformed_dataset))):
    sample = transformed_dataset[i]
    print(i, sample['image'].size(), sample['landmarks'].size())

dataloader = DataLoader(transformed_dataset, batch_size=4,
                        shuffle=True, num_workers=4)

Exemplo n.º 6
0
def main():
    """Command-line entry point: parse arguments, then build and train a model.

    Parses the CLI options and creates the checkpoint directory ``--root``.
    Builds the model, loss and train/validation loaders, and writes the
    parsed arguments to ``<root>/params.json``. Hands everything to
    ``utils.train``.
    """
    import warnings

    parser = argparse.ArgumentParser()

    arg = parser.add_argument
    arg('--jaccard-weight', default=1, type=float)
    arg('--device-ids',
        type=str,
        default='0',
        help='For example 0,1 to run on two GPUs')
    arg('--fold', type=int, help='fold', default=0)
    arg('--root', default='runs/debug', help='checkpoint root')
    arg('--batch-size', type=int, default=8)
    arg('--n-epochs', type=int, default=14)
    arg('--lr', type=float, default=0.000001)
    arg('--workers', type=int, default=8)
    arg('--type',
        type=str,
        default='binary',
        choices=['binary', 'parts', 'instruments'])
    arg('--model',
        type=str,
        default='TernausNet',
        choices=['UNet', 'UNet11', 'LinkNet34', 'TernausNet'])

    args = parser.parse_args()

    root = Path(args.root)
    root.mkdir(exist_ok=True, parents=True)

    # Output channels per problem type; 'binary' uses a single channel.
    num_classes = {'parts': 3, 'instruments': 8}.get(args.type, 1)

    # Only TernausNet34 is wired up in this script. Previously every --model
    # choice silently built it (both if/else branches were identical). Keep
    # that fallback, but make it visible to the user.
    if args.model != 'TernausNet':
        warnings.warn(f"--model {args.model} is not implemented in this "
                      "script; falling back to TernausNet34")
    model = TernausNet34(num_classes=num_classes)

    if torch.cuda.is_available():
        # An empty --device-ids passes device_ids=None to DataParallel.
        device_ids = (list(map(int, args.device_ids.split(',')))
                      if args.device_ids else None)
        model = nn.DataParallel(model, device_ids=device_ids).cuda()

    if args.type == 'binary':
        loss = LossBinary(jaccard_weight=args.jaccard_weight)
    else:
        loss = LossMulti(num_classes=num_classes,
                         jaccard_weight=args.jaccard_weight)

    cudnn.benchmark = True

    def make_loader(file_names,
                    shuffle=False,
                    transform=None,
                    mode='train',
                    problem_type='binary'):
        # Wrap a MapDataset in a DataLoader using the CLI batch size/workers.
        return DataLoader(dataset=MapDataset(file_names,
                                             transform=transform,
                                             problem_type=problem_type,
                                             mode=mode),
                          shuffle=shuffle,
                          num_workers=args.workers,
                          batch_size=args.batch_size,
                          pin_memory=torch.cuda.is_available())

    train_transform = DualCompose([
        # With prob=0.5, OneOrOther presumably picks one of the two
        # alternatives (a random distortion, or ShiftScaleRotate) — confirm
        # against the transforms module.
        OneOrOther(OneOf([
            Distort1(distort_limit=0.05, shift_limit=0.05),
            Distort2(num_steps=2, distort_limit=0.05)
        ]),
                   ShiftScaleRotate(shift_limit=0.0625,
                                    scale_limit=0.10,
                                    rotate_limit=45),
                   prob=0.5),
        RandomRotate90(),
        RandomCrop([256, 256]),
        RandomFlip(prob=0.5),
        Transpose(prob=0.5),
        ImageOnly(RandomContrast(limit=0.2, prob=0.5)),
        ImageOnly(RandomFilter(limit=0.5, prob=0.2)),
        ImageOnly(RandomHueSaturationValue(prob=0.2)),
        ImageOnly(RandomBrightness()),
        ImageOnly(Normalize())
    ])

    val_transform = DualCompose([
        Rescale([256, 256]),
        ImageOnly(Normalize())
    ])

    train_loader = make_loader(TRAIN_ANNOTATIONS_PATH,
                               shuffle=True,
                               transform=train_transform,
                               problem_type=args.type)
    valid_loader = make_loader(VAL_ANNOTATIONS_PATH,
                               transform=val_transform,
                               mode='valid',
                               problem_type=args.type)

    # Persist the run configuration next to the checkpoints.
    root.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))

    valid = validation_binary if args.type == 'binary' else validation_multi

    utils.train(init_optimizer=lambda lr: Adam(model.parameters(), lr=lr),
                args=args,
                model=model,
                criterion=loss,
                train_loader=train_loader,
                valid_loader=valid_loader,
                validation=valid,
                fold=args.fold,
                num_classes=num_classes)
Exemplo n.º 7
0
def main(args):
    """Train or evaluate a semantic-segmentation model on Pascal VOC.

    Behaviour is driven by ``args``:
    - ``test_only``: evaluate only.
    - ``show_samples`` / ``show_preds``: plot a training batch, or predictions
      on one, then exit.
    - ``lr_finder`` / ``check_setup``: run a diagnostic.
    - otherwise: a full ``fit_n_epochs`` run, optionally logged to
      Weights & Biases (``wb``).

    Raises:
        ValueError: if ``args.source``, ``args.loss`` or ``args.opt`` is not a
            supported value, or if a training-batch plot is requested while
            the training set was not loaded (``test_only``).
    """
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    interpolation_mode = InterpolationMode.BILINEAR

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size,
                                                     interpolation_mode),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        if train_loader is None:
            raise ValueError("show_samples needs the training set, which is not loaded when test_only is set")
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size),
                                             interpolation_mode),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    # Model: previously an unknown source left `model` unbound (NameError).
    source = args.source.lower()
    if source == 'holocron':
        model = segmentation.__dict__[args.arch](args.pretrained,
                                                 num_classes=len(VOC_CLASSES))
    elif source == 'torchvision':
        model = tv_segmentation.__dict__[args.arch](
            args.pretrained, num_classes=len(VOC_CLASSES))
    else:
        raise ValueError(f"Unsupported model source: {args.source!r}")

    # Loss setup: optionally re-weight the background class (index 0).
    loss_weight = None
    if isinstance(args.bg_factor, float) and args.bg_factor != 1:
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight,
                                        ignore_index=255,
                                        label_smoothing=args.label_smoothing)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight,
                                                  ignore_index=255,
                                                  xi=3)
    else:
        raise ValueError(f"Unsupported loss: {args.loss!r}")

    # Optimizer setup (trainable parameters only)
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)
    else:
        raise ValueError(f"Unsupported optimizer: {args.opt!r}")

    def log_wb(metrics):
        # Forward per-epoch metrics to W&B only when logging is enabled.
        if args.wb:
            wandb.log(metrics)

    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES),
                                  amp=args.amp,
                                  on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        if train_loader is None:
            raise ValueError("show_preds needs the training set, which is not loaded when test_only is set")
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until,
                        norm_weight_decay=args.norm_weight_decay,
                        num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-semantic-segmentation",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "source": args.source,
                             "input_size": crop_size,
                             "optimizer": args.opt,
                             "dataset": "Pascal VOC2012 Segmentation",
                             "loss": args.loss,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
Exemplo n.º 8
0
from config import config
import torch.nn as nn
import os

# Training loader over the LSPMPIILIP keypoint dataset (presumably a merged
# LSP/MPII/LIP set — confirm). Augmentations run in the listed order, ending
# with a resize to config.size.
trainloader = DataLoader(
    LSPMPIILIP(
        config.train_json_file,
        config.train_img_dir,
        config.size,
        config.num_kpt,
        config.s,
        trans=Compose([
            RandomAddColorNoise(config.num_kpt, config.min_gauss,
                                config.max_gauss, config.percentage),
            RandomBrightnessContrast(config.num_kpt, config.min_alpha,
                                     config.max_alpha),
            RandomCrop(config.num_kpt, config.ratio_max_x,
                       config.ratio_max_y, config.center_perturb_max),
            RandomRotate(config.num_kpt, config.max_degree),
            RandomHorizontalFlip(config.num_kpt, config.prob),
            Resized(config.num_kpt, config.size),
        ]),
    ),
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
)

model = build_model(config)

# Heatmap and offset losses.
heat_criterion, offset_criterion = Focal2DLoss(), nn.SmoothL1Loss()
# NOTE(review): only heat_criterion is moved to the GPU here, and the model is
# not moved in this visible snippet — confirm that happens elsewhere.
heat_criterion = heat_criterion.cuda() if config.cuda else heat_criterion
def init(batch_size, state, input_sizes, std, mean, dataset, city_aug=0):
    """Build the data loaders for training or evaluation.

    Args:
        batch_size: Batch size of the training loader, and of the validation
            loader except when ``city_aug == 1`` (then 1).
        state: ``1`` returns only the validation loader. Any other value also
            builds and returns the training loader.
        input_sizes: Sizes indexed as used below: ``[0]`` crop / train size
            (also the RandomResize lower bound), ``[1]`` RandomResize upper
            bound (or the GTAV pre-crop resize), ``[2]`` validation size.
        std, mean: Normalisation statistics passed to ``Normalize``.
        dataset: ``'voc'``, ``'city'``, ``'gtav'`` or ``'synthia'``.
        city_aug: Augmentation preset for the non-VOC datasets:
            0 standard City, 1 "City big", 2 ERFNet, 3 SYNTHIA & GTAV.

    Returns:
        ``val_loader`` if ``state == 1``, else ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``dataset`` is not supported.
    """
    # Transformations
    # ! Can't use torchvision.Transforms.Compose
    if dataset == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(),
            RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif dataset in ('city', 'gtav', 'synthia'):  # All the same size
        if dataset == 'city':
            base = base_city
        elif dataset == 'gtav':
            base = base_gtav
        else:
            base = base_synthia
        # Non-Cityscapes sources have irregular label IDs; let LabelMap
        # handle outliers for them.
        outlier = dataset != 'city'
        workers = 8

        if city_aug == 3:  # SYNTHIA & GTAV
            if dataset == 'gtav':
                transform_train = Compose([
                    ToTensor(),
                    Resize(size_label=input_sizes[1],
                           size_image=input_sizes[1]),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_city, outlier=outlier)
                ])
            else:
                # NOTE(review): dataset == 'city' with city_aug == 3 also
                # lands here and uses the SYNTHIA label map — confirm intended.
                transform_train = Compose([
                    ToTensor(),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_synthia, outlier=outlier)
                ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 2:  # ERFNet
            transform_train = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5)
            ])
            # NOTE(review): image resized to input_sizes[0] but label to
            # input_sizes[2] — confirm this asymmetry is intended.
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[2]),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 1:  # City big
            transform_train = Compose([
                ToTensor(),
                RandomCrop(size=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        else:  # Standard city
            transform_train = Compose([
                ToTensor(),
                RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city, outlier=outlier)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r} "
                         "(expected 'voc', 'city', 'gtav' or 'synthia')")

    # Not the actual test set (i.e. validation set). GTAV and SYNTHIA are
    # validated on Cityscapes.
    synthetic = dataset in ('gtav', 'synthia')
    test_set = StandardSegmentationDataset(
        root=base_city if synthetic else base,
        image_set='val',
        transforms=transform_test,
        data_set='city' if synthetic else dataset)
    # The "City big" preset validates one image per batch.
    val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                             batch_size=1 if city_aug == 1 else batch_size,
                                             num_workers=workers,
                                             shuffle=False)

    # Testing
    if state == 1:
        return val_loader

    # Training
    train_set = StandardSegmentationDataset(
        root=base,
        image_set='trainaug' if dataset == 'voc' else 'train',
        transforms=transform_train,
        data_set=dataset)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=batch_size,
                                               num_workers=workers,
                                               shuffle=True)
    return train_loader, val_loader
Exemplo n.º 10
0
def _build_criterion(loss_name, loss_weight):
    """Return the segmentation loss selected by ``loss_name``.

    Args:
        loss_name: one of ``'crossentropy'``, ``'label_smoothing'``,
            ``'focal'`` or ``'mc'``.
        loss_weight: optional per-class weight tensor, or ``None``.

    Raises:
        ValueError: if ``loss_name`` is not a supported loss.
    """
    # Every loss ignores target value 255.
    if loss_name == 'crossentropy':
        return nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255)
    if loss_name == 'label_smoothing':
        return holocron.nn.LabelSmoothingCrossEntropy(weight=loss_weight,
                                                      ignore_index=255)
    if loss_name == 'focal':
        return holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    if loss_name == 'mc':
        return holocron.nn.MutualChannelLoss(weight=loss_weight,
                                             ignore_index=255)
    raise ValueError(f"Unsupported loss: {loss_name!r}")


def _build_optimizer(opt_name, params, lr, weight_decay):
    """Return the optimizer selected by ``opt_name`` over ``params``.

    Args:
        opt_name: one of ``'sgd'``, ``'adam'``, ``'radam'``, ``'adamp'``
            or ``'adabelief'``.
        params: list of trainable parameters.
        lr: learning rate.
        weight_decay: weight decay coefficient.

    Raises:
        ValueError: if ``opt_name`` is not a supported optimizer.
    """
    if opt_name == 'sgd':
        return torch.optim.SGD(params,
                               lr,
                               momentum=0.9,
                               weight_decay=weight_decay)
    # All adaptive optimizers share the same hyper-parameters.
    adaptive_kwargs = dict(betas=(0.95, 0.99),
                           eps=1e-6,
                           weight_decay=weight_decay)
    if opt_name == 'adam':
        return torch.optim.Adam(params, lr, **adaptive_kwargs)
    if opt_name == 'radam':
        return holocron.optim.RAdam(params, lr, **adaptive_kwargs)
    if opt_name == 'adamp':
        return holocron.optim.AdamP(params, lr, **adaptive_kwargs)
    if opt_name == 'adabelief':
        return holocron.optim.AdaBelief(params, lr, **adaptive_kwargs)
    raise ValueError(f"Unsupported optimizer: {opt_name!r}")


def main(args):
    """Train or evaluate a VOC segmentation model according to ``args``.

    Depending on the flags in ``args`` this either plots samples, plots
    predictions, evaluates, runs an LR finder, checks batch overfitting, or
    trains for ``args.epochs`` epochs.

    Raises:
        ValueError: if ``args.loss`` / ``args.opt`` is unsupported, or if a
            visualisation flag is requested without a loader to draw from.
    """
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # The training loader is not built in test_only mode.
        if train_loader is None:
            raise ValueError(
                "show_samples requires the training loader, "
                "which is not built when test_only is set")
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size)),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = segmentation.__dict__[args.model](
        args.pretrained,
        not args.pretrained,
        num_classes=len(VOC_CLASSES),
    )

    # Loss setup: an explicit float bg_factor re-weights class 0 only.
    loss_weight = None
    if isinstance(args.bg_factor, float):
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    criterion = _build_criterion(args.loss, loss_weight)

    # Optimizer setup (frozen parameters are excluded)
    model_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = _build_optimizer(args.opt, model_params, args.lr,
                                 args.weight_decay)

    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES))
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        # Prefer training samples; fall back to validation samples when the
        # training loader was not built (test_only mode).
        loader = train_loader if train_loader is not None else val_loader
        if loader is None:
            raise ValueError(
                "show_preds requires a training or validation loader")
        x, target = next(iter(loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                # Use the GPU index selected in args.device (presumably where
                # the trainer placed the model) rather than the current
                # default CUDA device.
                x = x.cuda(args.device)
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Exemplo n.º 11
0
    json_file_name = str(data_root / "labels.json")
    train_file_names, val_file_names = get_color_file_names_both_cam(
        root=data_root)

    lr = 2.0e-4
    gaussian_std = 0.05
    n_epochs = 600
    # scale = 4

    img_width = 256  #1280 // scale
    img_height = 256  #1024 // scale
    offset = 30
    loss_ratio = 0.5
    train_transform = DualCompose([
        Resize(w=img_width + offset, h=img_height + offset),
        RandomCrop(size=(img_height, img_width)),
        HorizontalFlip(),
        VerticalFlip(),
        RandomColorDual(limit=0.3, prob=1.0),
        # RandomBrightnessDual(limit=0.3),
        # RandomContrastDual(limit=0.3),
        # RandomSaturationDual(limit=0.3),
        Normalize(normalize_mask=True)
    ])

    valid_transform = DualCompose(
        [Resize(w=img_width, h=img_height),
         Normalize(normalize_mask=True)])

    train_dataset = Challenge2018ColorizationDataset(
        image_file_names=train_file_names,
Exemplo n.º 12
0
def train_test_single_img(hr_img_path, lr_img_path, config):
    """Train a ZSSR model on one low-resolution image, then evaluate it.

    Args:
        hr_img_path: path of the high-resolution reference image (used only
            for evaluation).
        lr_img_path: path of the low-resolution image to train on.
        config: dict read for the keys ``"sr_factor"``, ``"model"``,
            ``"device"``, ``"upsample"`` (``"pixelshuffle"`` or ``"cubic"``),
            ``"crop_size"`` (cubic only), ``"learning_rate"`` and
            ``"number_of_iterations"``.

    Returns:
        Whatever ``test_single_img`` returns for the trained model.

    Raises:
        ValueError: if ``config["upsample"]`` is not a supported mode.
    """
    sr_factor = config["sr_factor"]
    device = config["device"]
    upsample_mode = config["upsample"]
    # Fail fast: an unknown mode would otherwise leave the optimizer without
    # parameters and upsample_fn unbound.
    if upsample_mode not in ("pixelshuffle", "cubic"):
        raise ValueError(f"Unsupported upsample mode: {upsample_mode!r}")

    img = Image.open(hr_img_path)
    lr_img = Image.open(lr_img_path)
    print(f"Starting training on {hr_img_path} with resolution factor {sr_factor}")
    dataset = ZSSRDataset.from_image(lr_img, sr_factor)
    data_sampler = ZSSRSampler(dataset)
    # NOTE(review): in pixelshuffle mode this model is replaced below, so
    # config["model"] only takes effect for the cubic path.
    if config["model"] == "zssr":
        model = ZSSRModel()
    else:
        vdsr_backbone = torch.load("./models/model_epoch_100.pth")
        vdsr_backbone.to(device)
        model = ZSSRModelWithBackbone(vdsr_backbone, True)
    model.to(device)
    model.train()

    if upsample_mode == "pixelshuffle":
        trans = transforms.Compose([
            ToTensor(),
            # RandomCrop(config["crop_size"])
        ])
        # One x2 pixel-shuffle stage per factor of two; log2 is exact for
        # powers of two, unlike log(x, 2).
        num_stages = int(math.log2(sr_factor))
        upsample_layers = []
        for _ in range(num_stages):
            upsample_layers += [nn.Conv2d(3, 3 * 4, kernel_size=3, padding=3 // 2, bias=True),
                                nn.ReLU(),
                                nn.PixelShuffle(2)]
        upsample_model = nn.Sequential(*upsample_layers)
        upsample_model.to(device)
        upsample_model.train()
        upsample_fn = functools.partial(pixelshuffle_upsample, upsample_model=upsample_model)
        # Each upsample stage contains one conv, so the ZSSR body gets that
        # many fewer layers (same base-2 count as the stage loop above).
        model = ZSSRModel(layers_num=6 - num_stages)
        model.to(device)
        model.train()
        all_models = nn.ModuleList([upsample_model, model]).to(device)
    else:  # "cubic"
        trans = transforms.Compose([
            ToTensor(),
            RandomCrop(config["crop_size"])
        ])
        all_models = model
        upsample_fn = bicubic_upsample

    optimizer = torch.optim.Adam(all_models.parameters(), lr=config["learning_rate"])

    num_batches = config["number_of_iterations"]

    train_single_img(lr_img,
                     model,
                     upsample_fn,
                     data_sampler,
                     trans,
                     optimizer,
                     num_batches)
    model.eval()
    if upsample_mode == "pixelshuffle":
        upsample_model.eval()

    metrics = test_single_img(model,
                              upsample_fn,
                              lr_img,
                              img,
                              sr_factor)
    return metrics
Exemplo n.º 13
0
    if single_block:
        boxes = boxes[:1]
    boxes = torch.as_tensor(boxes).reshape(-1, 4)  # guard against no boxes
    target = BoxList(boxes, (width, height), mode="xyxy")

    classes = torch.tensor([1] * len(boxes), dtype=torch.int32)
    target.add_field('labels', classes)

    return target


if __name__ == '__main__':
    # Visual sanity check: render the first five training samples with their
    # ground-truth boxes outlined in red.
    from transforms import PadToDivisibility, RandomCrop
    from matplotlib import pyplot as plt
    from PIL import Image, ImageDraw

    transform = RandomCrop(32, 32)
    dataset = Scarlet300Dataset(split='train', transforms=transform)

    red = (255, 0, 0)
    for sample_idx in range(5):
        image, target = dataset[sample_idx]

        canvas = ImageDraw.Draw(image)
        boxes_xyxy = target.convert('xyxy').bbox
        for box in boxes_xyxy:
            corners = (box[0], box[1], box[2], box[3])
            canvas.rectangle(corners, outline=red)

        plt.figure()
        plt.imshow(image)

    plt.show()