Esempio n. 1
0
def load_data(datadir, img_size=416, crop_pct=0.875):
    """Build the Pascal VOC detection train/val datasets and their samplers.

    Args:
        datadir: root folder where the VOC archive is (or will be) stored
        img_size: side length of the square training crop, in pixels
        crop_pct: center-crop fraction used to derive the validation resize

    Returns:
        (train dataset, val dataset, train sampler, val sampler)
    """
    print("Loading data")
    # ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Resize target applied before the validation center crop
    scale_size = int(math.floor(img_size / crop_pct))

    print("Loading training data")
    start = time.time()
    train_tfs = Compose([
        VOCTargetTransform(classes),
        RandomResizedCrop((img_size, img_size), scale=(0.3, 1.0)),
        RandomHorizontalFlip(),
        convert_to_relative,
        ImageTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                              saturation=0.1, hue=0.02)),
        ImageTransform(transforms.ToTensor()),
        ImageTransform(normalize),
    ])
    dataset = VOCDetection(datadir, image_set='train', download=True,
                           transforms=train_tfs)
    print("Took", time.time() - start)

    print("Loading validation data")
    start = time.time()
    val_tfs = Compose([
        VOCTargetTransform(classes),
        Resize(scale_size),
        CenterCrop(img_size),
        convert_to_relative,
        ImageTransform(transforms.ToTensor()),
        ImageTransform(normalize),
    ])
    dataset_test = VOCDetection(datadir, image_set='val', download=True,
                                transforms=val_tfs)
    print("Took", time.time() - start)

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
Esempio n. 2
0
def main():
    """Dump batches of dataset images to disk as JPEGs (visual inspection),
    then iterate the loader from a subprocess.

    Side effects: writes files under 'datatest/sev_img/', spawns one
    multiprocessing.Process that consumes the loader via `iter_f`.
    """
    class ImshowCollate(object):
        """Collate hook that writes every image of the batch to disk and
        returns the batch unchanged."""

        def __call__(self, batch):
            images, labels = zip(*batch)
            for idx, img in enumerate(images):
                # Undo ToTensor layout/scaling: CHW float in [0, 1] -> HWC * 255
                img = img.cpu().numpy().transpose((1, 2, 0)) * 255  #totensor
                cv2.imwrite(
                    'datatest/sev_img/img' + str(idx) + '——' +
                    str(labels[idx]) + '.jpg', img)
                # print(img.shape)
            return images, labels

    from transforms import Compose, Normalize, RandomResizedCrop, RandomHorizontalFlip, \
        ColorJitter, ToTensor, Lighting

    # BUG FIX: the original declared `batch_size = 16` but the DataLoader
    # hard-coded batch_size=100, so the variable was dead and misleading.
    # Keep the effective value (100) and actually wire it through.
    batch_size = 100
    # NOTE(review): `normalize` is built but the corresponding pipeline entry
    # is commented out below, so it is currently unused -- kept for parity.
    normalize = Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])

    dataset = FileListLabeledDataset(
        '/workspace/mnt/group/algo/yangdecheng/work/multi_task/pytorch-train/datatest/test.txt',
        '/workspace/mnt/group/algo/yangdecheng/work/multi_task/pytorch-train/datatest/pic',
        Compose([
            RandomResizedCrop((112),
                              scale=(0.7, 1.2),
                              ratio=(1. / 1., 4. / 1.)),
            RandomHorizontalFlip(),
            ColorJitter(brightness=[0.5, 1.5],
                        contrast=[0.5, 1.5],
                        saturation=[0.5, 1.5],
                        hue=0),
            ToTensor(),
            # PCA-based lighting noise (eigenvalues / eigenvectors)
            Lighting(1, [0.2175, 0.0188, 0.0045],
                     [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140],
                      [-0.5836, -0.6948, 0.4203]]),  #0.1
            # 				normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=10,
                                               pin_memory=True,
                                               sampler=None,
                                               collate_fn=ImshowCollate())

    from multiprocessing import Process
    p_list = []
    for i in range(1):
        p_list.append(Process(target=iter_f, args=(train_loader, )))
    for p in p_list:
        p.start()
    for p in p_list:
        p.join()
Esempio n. 3
0
def load_data(datadir):
    """Build the Pascal VOC segmentation train/val datasets and their samplers.

    Args:
        datadir: root folder where the VOC archive is (or will be) stored

    Returns:
        (train dataset, val dataset, train sampler, val sampler)
    """
    print("Loading data")
    # ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    base_size = 320
    crop_size = 256
    # Random-resize bounds: half to double the base size
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    print("Loading training data")
    start = time.time()
    train_tfs = Compose([
        RandomResize(min_size, max_size),
        RandomCrop(crop_size),
        RandomHorizontalFlip(0.5),
        SampleTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                               saturation=0.1, hue=0.02)),
        ToTensor(),
        SampleTransform(normalize),
    ])
    dataset = VOCSegmentation(datadir, image_set='train', download=True,
                              transforms=train_tfs)
    print("Took", time.time() - start)

    print("Loading validation data")
    start = time.time()
    val_tfs = Compose([
        RandomResize(base_size, base_size),
        ToTensor(),
        SampleTransform(normalize),
    ])
    dataset_test = VOCSegmentation(datadir, image_set='val', download=True,
                                   transforms=val_tfs)
    print("Took", time.time() - start)

    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
Esempio n. 4
0
def load_data(datadir, img_size=416, crop_pct=0.875):
    """Build the Pascal VOC detection train/val datasets.

    Args:
        datadir: root folder where the VOC archive is (or will be) stored
        img_size: side length of the square training crop, in pixels
        crop_pct: center-crop fraction used to derive the validation resize

    Returns:
        (train dataset, val dataset)
    """
    # ImageNet channel statistics
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Resize target applied before the validation center crop
    scale_size = int(math.floor(img_size / crop_pct))

    print("Loading training data")
    start = time.time()
    train_set = VOCDetection(
        datadir,
        image_set='train',
        download=True,
        transforms=Compose([
            VOCTargetTransform(VOC_CLASSES),
            RandomResizedCrop((img_size, img_size), scale=(0.3, 1.0)),
            RandomHorizontalFlip(),
            convert_to_relative,
            ImageTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                                  saturation=0.1, hue=0.02)),
            ImageTransform(transforms.ToTensor()),
            ImageTransform(normalize),
        ]))
    print("Took", time.time() - start)

    print("Loading validation data")
    start = time.time()
    val_set = VOCDetection(
        datadir,
        image_set='val',
        download=True,
        transforms=Compose([
            VOCTargetTransform(VOC_CLASSES),
            Resize(scale_size),
            CenterCrop(img_size),
            convert_to_relative,
            ImageTransform(transforms.ToTensor()),
            ImageTransform(normalize),
        ]))
    print("Took", time.time() - start)

    return train_set, val_set
Esempio n. 5
0
def get_transform(train=True, fixsize=False, img_size=416, min_size=800, max_size=1333,
                  image_mean=None, image_std=None, advanced=False):
    """Assemble the detection transform pipeline.

    Args:
        train: include training-time augmentation (Augment, random flip)
        fixsize: resize to a fixed square (`img_size`) instead of min/max bounds
        img_size: target side when `fixsize` is True
        min_size, max_size: resize bounds when `fixsize` is False
        image_mean, image_std: normalization statistics (ImageNet defaults)
        advanced: forwarded to Augment for stronger augmentation

    Returns:
        a Compose pipeline
    """
    if image_mean is None:
        image_mean = [0.485, 0.456, 0.406]
    if image_std is None:
        image_std = [0.229, 0.224, 0.225]

    resize = ResizeFixSize(img_size) if fixsize else ResizeMinMax(min_size, max_size)

    steps = []
    if train:
        steps.append(Augment(advanced))
    steps.append(ToTensor())
    steps.append(resize)
    if train:
        steps.append(RandomHorizontalFlip(0.5))
    steps.append(Normalize(image_mean, image_std))
    return Compose(steps)
Esempio n. 6
0
 def get_training_loader(img_root, label_root, file_list, batch_size,
                         img_height, img_width, num_class):
     """Build a shuffled DataLoader with train-time augmentation.

     Note: `num_class` is currently only relevant for the commented-out
     GenOneHotLabel step; it is kept for interface compatibility.
     """
     augmentation = transforms.Compose([
         RandomHorizontalFlip(),
         # Resize slightly larger than the target so RandomCrop adds jitter
         Resize((img_height + 5, img_width + 5)),
         RandomCrop((img_height, img_width)),
         ToTensor(),
         Normalize(imagenet_stats['mean'], imagenet_stats['std']),
         # GenOneHotLabel(num_class),
     ])
     transformed_dataset = VOCTestDataset(img_root, label_root, file_list,
                                          transform=augmentation)
     return DataLoader(
         transformed_dataset,
         batch_size,
         shuffle=True,
         num_workers=0,
         pin_memory=False,
     )
Esempio n. 7
0
    def __init__(self, data, roi_size=64, zoom_range=(0.8, 1.25),
                 samples_per_epoch=100000):
        """Store the data source and build the augmentation pipeline.

        A padded crop is taken first so that the affine (zoom) transform
        has margin to work with; the padding is removed afterwards.

        Args:
            data: underlying data source (format defined by the caller)
            roi_size: side length of the final region of interest
            zoom_range: (min, max) scale factors for AffineTransform
            samples_per_epoch: nominal epoch length
        """
        self.data = data
        self.roi_size = roi_size

        # Extra margin so a rotated ROI still fits: ceil(r * (sqrt(2) - 1) / 2)
        pad = math.ceil(self.roi_size * (math.sqrt(2) - 1) / 2)

        self.transforms = Compose([
            RandomCrop(roi_size + 2 * pad),
            AffineTransform(zoom_range),
            RemovePadding(pad),
            RandomVerticalFlip(),
            RandomHorizontalFlip(),
        ])

        self.samples_per_epoch = samples_per_epoch
Esempio n. 8
0
def get_transform_fixsize(train=True, img_size=416,
                          image_mean=None, image_std=None, advanced=False):
    """Assemble a fixed-size transform pipeline (pad, then resize to a square).

    Args:
        train: include training-time augmentation (Augment, random flip)
        img_size: target side length after Resize
        image_mean, image_std: normalization statistics (ImageNet defaults)
        advanced: forwarded to Augment for stronger augmentation

    Returns:
        a Compose pipeline
    """
    if image_mean is None:
        image_mean = [0.485, 0.456, 0.406]
    if image_std is None:
        image_std = [0.229, 0.224, 0.225]

    steps = [Augment(advanced)] if train else []
    steps += [Pad(), ToTensor(), Resize(img_size)]
    if train:
        steps.append(RandomHorizontalFlip(0.5))
    steps.append(Normalize(image_mean, image_std))
    return Compose(steps)
Esempio n. 9
0
def main(args):
    """CLI entry point: train/evaluate an object-detection model on Pascal VOC.

    Builds the train/val pipelines and loaders, instantiates the model and
    optimizer from `args`, then dispatches on the flags: --show-samples,
    --test-only, --lr-finder, --check-setup, or a full training run.
    """

    print(args)

    # Let cuDNN pick the fastest kernels for fixed input sizes
    torch.backends.cudnn.benchmark = True

    # Data loading
    # ImageNet channel statistics
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    crop_pct = 0.875
    # Resize target applied before the validation center crop
    scale_size = int(math.floor(args.img_size / crop_pct))

    train_loader, val_loader = None, None

    if not args.test_only:
        st = time.time()
        # NOTE(review): `datadir` is neither a parameter nor read from `args`
        # in this function -- presumably a module-level global; verify.
        train_set = VOCDetection(datadir,
                                 image_set='train',
                                 download=True,
                                 transforms=Compose([
                                     VOCTargetTransform(VOC_CLASSES),
                                     RandomResizedCrop(
                                         (args.img_size, args.img_size),
                                         scale=(0.3, 1.0)),
                                     RandomHorizontalFlip(),
                                     convert_to_relative,
                                     ImageTransform(
                                         T.ColorJitter(brightness=0.3,
                                                       contrast=0.3,
                                                       saturation=0.1,
                                                       hue=0.02)),
                                     ImageTransform(T.ToTensor()),
                                     ImageTransform(normalize)
                                 ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            collate_fn=collate_fn,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # NOTE(review): if --test-only is also set, train_loader is None here
        # and next(iter(None)) raises -- confirm flag combinations upstream.
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCDetection(datadir,
                               image_set='val',
                               download=True,
                               transforms=Compose([
                                   VOCTargetTransform(VOC_CLASSES),
                                   Resize(scale_size),
                                   CenterCrop(args.img_size),
                                   convert_to_relative,
                                   ImageTransform(T.ToTensor()),
                                   ImageTransform(normalize)
                               ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            collate_fn=collate_fn,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    # Look up the architecture by name in the detection module namespace
    model = detection.__dict__[args.model](args.pretrained,
                                           num_classes=len(VOC_CLASSES),
                                           pretrained_backbone=True)

    # Only optimize parameters that require gradients (frozen layers excluded)
    model_params = [p for p in model.parameters() if p.requires_grad]
    # NOTE: if args.opt matches none of these branches, `optimizer` is left
    # unbound and the DetectionTrainer call below raises NameError.
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params,
                                     args.lr,
                                     betas=(0.95, 0.99),
                                     eps=1e-6,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'ranger':
        # Ranger = RAdam wrapped in Lookahead
        optimizer = Lookahead(
            holocron.optim.RAdam(model_params,
                                 args.lr,
                                 betas=(0.95, 0.99),
                                 eps=1e-6,
                                 weight_decay=args.weight_decay))
    elif args.opt == 'tadam':
        optimizer = holocron.optim.TAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)

    # Criterion argument is None -- the trainer presumably supplies its own loss
    trainer = DetectionTrainer(model, train_loader, val_loader, None,
                               optimizer, args.device, args.output_file)

    if args.resume:
        print(f"Resuming {args.resume}")
        # Load on CPU first to avoid GPU memory spikes / device mismatches
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Loc error: {eval_metrics['loc_err']:.2%} | Clf error: {eval_metrics['clf_err']:.2%} | "
            f"Det error: {eval_metrics['det_err']:.2%}")
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        # Cap the search at 100 iterations (or one epoch, whichever is shorter)
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Esempio n. 10
0
def main(args):
    """CLI entry point: train/evaluate a semantic-segmentation model on Pascal VOC.

    Builds the train/val pipelines and loaders, instantiates model, loss and
    optimizer from `args`, then dispatches on the flags: --show-samples,
    --show-preds, --test-only, --lr-finder, --check-setup, or a full training
    run (optionally logged to Weights & Biases).
    """

    print(args)

    # Let cuDNN pick the fastest kernels for fixed input sizes
    torch.backends.cudnn.benchmark = True

    # Data loading
    # ImageNet channel statistics
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    # Random-resize bounds: half to double the base size
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    interpolation_mode = InterpolationMode.BILINEAR

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size,
                                                     interpolation_mode),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # NOTE(review): if --test-only is also set, train_loader is None here
        # and next(iter(None)) raises -- confirm flag combinations upstream.
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size),
                                             interpolation_mode),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    # Architecture lookup by name; NOTE: `model` is left unbound (NameError
    # later) if args.source is neither 'holocron' nor 'torchvision'.
    if args.source.lower() == 'holocron':
        model = segmentation.__dict__[args.arch](args.pretrained,
                                                 num_classes=len(VOC_CLASSES))
    elif args.source.lower() == 'torchvision':
        model = tv_segmentation.__dict__[args.arch](
            args.pretrained, num_classes=len(VOC_CLASSES))

    # Loss setup
    # Optionally re-weight the background class (index 0)
    loss_weight = None
    if isinstance(args.bg_factor, float) and args.bg_factor != 1:
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    # NOTE: same unbound-name hazard as `model` if args.loss is unrecognized.
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight,
                                        ignore_index=255,
                                        label_smoothing=args.label_smoothing)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight,
                                                  ignore_index=255,
                                                  xi=3)

    # Optimizer setup
    # Only optimize parameters that require gradients (frozen layers excluded)
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)

    # No-op unless W&B logging was requested
    log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES),
                                  amp=args.amp,
                                  on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        # Load on CPU first to avoid GPU memory spikes / device mismatches
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        # Run one forward pass on a training batch and plot the predictions
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        # Cap the search at 100 iterations (or one epoch, whichever is shorter)
        trainer.lr_find(args.freeze_until,
                        norm_weight_decay=args.norm_weight_decay,
                        num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-semantic-segmentation",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "source": args.source,
                             "input_size": 256,
                             "optimizer": args.opt,
                             "dataset": "Pascal VOC2012 Segmentation",
                             "loss": args.loss,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
0
# import cv2 as cv
from torch.autograd import Variable
from utilities import build_class_names, draw_detection, confidence_threshold, max_box, nms, im2PIL, imshow
from transforms import RandomBlur, RandomHorizontalFlip, RandomVerticalFlip
from PIL import Image, ImageOps
from yolo_v1 import YOLOv1
from data.voc_dataset import VOCDataset

# Image-only pipeline: tensor conversion, ImageNet normalization, and a
# guaranteed (p=1) random-erasing patch with a fixed 0.1 aspect ratio.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=1, scale=(0.02, 0.33), ratio=(0.1, 0.1))
])

# Paired (image + target) transform: horizontal flip disabled (probability=0),
# vertical flip always applied (probability=1).
pair_transform = transforms.Compose([
    RandomHorizontalFlip(probability=0),
    RandomVerticalFlip(probability=1)
])

_IMAGE_SIZE_ = (448, 448)  # network input resolution (width, height)
_GRID_SIZE_ = 7  # prediction grid is _GRID_SIZE_ x _GRID_SIZE_ cells
_STRIDE_ = _IMAGE_SIZE_[0] / _GRID_SIZE_  # pixels covered by one grid cell
class_names = build_class_names("./voc.names")

# Validation split of the VOC dataset, sized/gridded to match the model
dataset = VOCDataset(f"./data/val.txt", image_size=_IMAGE_SIZE_, grid_size=_GRID_SIZE_)

class_color_mapping = {
    0: "red", 1: "blue", 2: "AntiqueWhite", 3: "Aquamarine", 4: "Black",
    5: "SeaGreen", 6: "Chartreuse", 7: "Chocolate", 8:"MediumAquaMarine", 9: "DarkGoldenRod",
    10: "DarkGreen", 11: "DarkOrchid", 12: "DeepSkyBlue", 13: "DarkSlateGrey", 14: "DarkSalmon",
    15: "DimGrey", 16: "SlateBlue", 17: "Fuchsia", 18: "Gold", 19: "IndianRed"
Esempio n. 12
0
# Keypoint-training DataLoader: the LSPMPIILIP dataset is wrapped with a
# Compose pipeline of color noise, brightness/contrast jitter, geometric
# augmentation (crop / rotate / flip), and a final resize to the model input.
# All augmentation hyperparameters come from the `config` module.
trainloader = DataLoader(LSPMPIILIP(
    config.train_json_file,
    config.train_img_dir,
    config.size,
    config.num_kpt,
    config.s,
    trans=Compose([
        RandomAddColorNoise(config.num_kpt, config.min_gauss, config.max_gauss,
                            config.percentage),
        RandomBrightnessContrast(config.num_kpt, config.min_alpha,
                                 config.max_alpha),
        RandomCrop(config.num_kpt, config.ratio_max_x, config.ratio_max_y,
                   config.center_perturb_max),
        RandomRotate(config.num_kpt, config.max_degree),
        RandomHorizontalFlip(config.num_kpt, config.prob),
        Resized(config.num_kpt, config.size),
    ])),
                         batch_size=config.batch_size,
                         shuffle=True,
                         num_workers=config.num_workers,
                         pin_memory=True)

model = build_model(config)

# Heatmap loss (focal) plus smooth-L1 for the keypoint offset regression
heat_criterion = Focal2DLoss()
offset_criterion = nn.SmoothL1Loss()
if config.cuda:
    heat_criterion = heat_criterion.cuda()
    offset_criterion = offset_criterion.cuda()
    # NOTE(review): cudnn is imported here but not used in this snippet --
    # presumably configured (e.g. cudnn.benchmark) further down; verify.
    import torch.backends.cudnn as cudnn
def init(batch_size, state, input_sizes, std, mean, dataset, city_aug=0):
    """Build the semantic-segmentation data loader(s) for one dataset.

    Args:
        batch_size: batch size for the training loader and (except for
            city_aug == 1, which evaluates one image at a time) the
            validation loader.
        state: 1 -> return only the validation loader (testing);
            any other value -> return (train_loader, val_loader).
        input_sizes: 3-tuple of sizes consumed by the transform
            pipelines: [0] train crop size, [1] train resize bound,
            [2] test/validation size.
        std, mean: per-channel normalization statistics.
        dataset: one of 'voc', 'city', 'gtav', 'synthia'.
        city_aug: augmentation variant for the city-like datasets
            (0 standard, 1 big crops, 2 ERFNet-style, 3 SYNTHIA/GTAV).

    Returns:
        val_loader if state == 1, otherwise (train_loader, val_loader).

    Raises:
        ValueError: if `dataset` is not a supported name.
    """
    # Transformations
    # ! Can't use torchvision.Transforms.Compose
    if dataset == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(),
            RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif dataset == 'city' or dataset == 'gtav' or dataset == 'synthia':  # All the same size
        if dataset == 'city':
            base = base_city
        elif dataset == 'gtav':
            base = base_gtav
        else:
            base = base_synthia
        # GTAV/SYNTHIA carry label IDs outside the city mapping, which
        # LabelMap must treat as outliers.
        outlier = False if dataset == 'city' else True
        workers = 8

        if city_aug == 3:  # SYNTHIA & GTAV
            if dataset == 'gtav':
                transform_train = Compose([
                    ToTensor(),
                    Resize(size_label=input_sizes[1],
                           size_image=input_sizes[1]),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_city, outlier=outlier)
                ])
            else:
                transform_train = Compose([
                    ToTensor(),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_synthia, outlier=outlier)
                ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 2:  # ERFNet
            transform_train = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[2]),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 1:  # City big
            transform_train = Compose([
                ToTensor(),
                RandomCrop(size=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        else:  # Standard city
            transform_train = Compose([
                ToTensor(),
                RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city, outlier=outlier)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
    else:
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    # Not the actual test set (i.e. validation set). GTAV/SYNTHIA models
    # are always validated on Cityscapes.
    test_set = StandardSegmentationDataset(
        root=base_city if dataset == 'gtav' or dataset == 'synthia' else base,
        image_set='val',
        transforms=transform_test,
        data_set='city'
        if dataset == 'gtav' or dataset == 'synthia' else dataset)

    # "City big" (city_aug == 1) images are too large to batch, so the
    # validation loader falls back to batch_size 1 in that case.
    val_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1 if city_aug == 1 else batch_size,
        num_workers=workers,
        shuffle=False)

    if state == 1:
        # Testing/validation only.
        return val_loader

    # Training: also build the training set and loader.
    train_set = StandardSegmentationDataset(
        root=base,
        image_set='trainaug' if dataset == 'voc' else 'train',
        transforms=transform_train,
        data_set=dataset)
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                               batch_size=batch_size,
                                               num_workers=workers,
                                               shuffle=True)
    return train_loader, val_loader
Esempio n. 14
0
    def __init__(
        self,
        data_src,
        folds2include=None,
        num_folds=5,
        samples_per_epoch=2000,
        roi_size=96,
        scale_int=(0, 255),
        norm_mean=0.,
        norm_sd=1.,
        zoom_range=(0.90, 1.1),
        prob_unseeded_patch=0.2,
        int_aug_offset=None,
        int_aug_expansion=None,
        valid_labels=None,  # if this is None it will include all the available labels
        is_preloaded=False,
        max_input_size=2048  # if any image is larger than this it will be split (only used with folds)
    ):
        """Set up a patch-sampling dataset from a directory or a single file.

        Args:
            data_src: path to either a directory of images (fold-less
                mode) or a single data file (preloaded fold mode).
            folds2include: folds to use; must be None when `data_src`
                is a directory.
            num_folds: total number of cross-validation folds.
            samples_per_epoch: number of random patches drawn per epoch.
            roi_size: side length of the output patch, before the
                rotation padding described below.
            scale_int: intensity range used for normalization.
            norm_mean, norm_sd: normalization statistics.
            zoom_range: scale range for the affine augmentation.
            prob_unseeded_patch: probability of sampling a patch that is
                not centered on a seed/annotation.
            int_aug_offset, int_aug_expansion: intensity-augmentation
                parameters (offset and expansion ranges).
            valid_labels: labels to keep; None keeps all available labels.
            is_preloaded: whether images are loaded into memory up front.
            max_input_size: images larger than this are split
                (only used with folds).
        """

        # Snapshot of attribute names before assignment, used below to
        # record exactly which constructor arguments became fields.
        _dum = set(dir(self))

        self.data_src = Path(data_src)
        if not self.data_src.exists():
            raise ValueError(f'`data_src` : `{data_src}` does not exists.')

        self.folds2include = folds2include
        self.num_folds = num_folds

        self.samples_per_epoch = samples_per_epoch
        self.roi_size = roi_size
        self.scale_int = scale_int

        self.norm_mean = norm_mean
        self.norm_sd = norm_sd

        self.zoom_range = zoom_range
        self.prob_unseeded_patch = prob_unseeded_patch
        self.int_aug_offset = int_aug_offset
        self.int_aug_expansion = int_aug_expansion
        self.valid_labels = valid_labels
        self.is_preloaded = is_preloaded
        self.max_input_size = max_input_size

        # Names of the input fields set above, so they can be accessed
        # generically later (e.g. for serialization/inspection).
        self._input_names = list(
            set(dir(self)) - _dum
        )  #i want the name of this fields so i can access them if necessary

        # Pad the crop so that an arbitrary rotation of the padded patch
        # still fully covers a roi_size x roi_size region
        # (sqrt(2) is the diagonal growth factor of a square).
        rotation_pad_size = math.ceil(self.roi_size * (math.sqrt(2) - 1) / 2)
        padded_roi_size = roi_size + 2 * rotation_pad_size

        # Random (training) pipeline: seeded crop -> affine zoom ->
        # strip rotation padding -> flips -> intensity normalization and
        # augmentation -> contours-to-mask -> dtype fix-up.
        transforms_random = [
            RandomCropWithSeeds(padded_roi_size, rotation_pad_size,
                                prob_unseeded_patch),
            AffineTransform(zoom_range),
            RemovePadding(rotation_pad_size),
            RandomVerticalFlip(),
            RandomHorizontalFlip(),
            NormalizeIntensity(scale_int, norm_mean, norm_sd),
            RandomIntensityOffset(int_aug_offset),
            RandomIntensityExpansion(int_aug_expansion),
            OutContours2Segmask(),
            FixDTypes()
            # ToTensor is deliberately NOT applied here: it crashes in the
            # dataloader when the batch size is large (>256).
        ]
        self.transforms_random = Compose(transforms_random)

        # Deterministic (full-image) pipeline used for evaluation.
        transforms_full = [
            NormalizeIntensity(scale_int),
            OutContours2Segmask(),
            FixDTypes(),
            ToTensor()
        ]
        self.transforms_full = Compose(transforms_full)
        # Hard-negative pool, filled externally during training if used.
        self.hard_neg_data = None

        # Directory source -> fold-less loading; file source -> must be
        # preloaded fold data.
        if self.data_src.is_dir():
            assert self.folds2include is None
            self.data = self.load_data_from_dir(self.data_src, padded_roi_size,
                                                self.is_preloaded)
        else:
            assert self.is_preloaded
            self.data = self.load_data_from_file(self.data_src)

        # Deterministic label ordering; 0 is implicitly background, so
        # class labels start at 1.
        self.type_ids = sorted(list(self.data.keys()))
        self.types2label = {k: (ii + 1) for ii, k in enumerate(self.type_ids)}

        # NOTE(review): attribute name is misspelled ("clases") but kept
        # for backward compatibility with existing callers.
        self.num_clases = len(self.type_ids)

        #flatten data so i can access the whole list by index
        self.data_indexes = [(_type, _fname, ii)
                             for _type, type_data in self.data.items()
                             for _fname, file_data in type_data.items()
                             for ii in range(len(file_data))]

        assert len(self.data_indexes) > 0  #makes sure there are valid files
Esempio n. 15
0
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--world_size',
                        type=int,
                        default=1,
                        help='number of GPUs to use')

    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--wd',
                        type=float,
                        default=1e-4,
                        help='weight decay (default: 5e-4)')
    parser.add_argument('--lr-decay-every',
                        type=int,
                        default=100,
                        help='learning rate decay by 10 every X epochs')
    parser.add_argument('--lr-decay-scalar',
                        type=float,
                        default=0.1,
                        help='--')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--run_test',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='run test only')

    parser.add_argument(
        '--limit_training_batches',
        type=int,
        default=-1,
        help='how many batches to do per training, -1 means as many as possible'
    )

    parser.add_argument('--no_grad_clip',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='turn off gradient clipping')

    parser.add_argument('--get_flops',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='add hooks to compute flops')

    parser.add_argument(
        '--get_inference_time',
        default=False,
        type=str2bool,
        nargs='?',
        help='runs valid multiple times and reports the result')

    parser.add_argument('--mgpu',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='use data paralization via multiple GPUs')

    parser.add_argument('--dataset',
                        default="MNIST",
                        type=str,
                        help='dataset for experiment, choice: MNIST, CIFAR10')

    parser.add_argument('--data',
                        metavar='DIR',
                        default='/imagenet',
                        help='path to imagenet dataset')

    parser.add_argument(
        '--model',
        default="lenet3",
        type=str,
        help='model selection, choices: lenet3, vgg, mobilenetv2, resnet18',
        choices=[
            "lenet3", "vgg", "mobilenetv2", "resnet18", "resnet152",
            "resnet50", "resnet50_noskip", "resnet20", "resnet34", "resnet101",
            "resnet101_noskip", "densenet201_imagenet", 'densenet121_imagenet',
            "multprun_gate5_gpu_0316_1", "mult_prun8_gpu", "multnas5_gpu"
        ])

    parser.add_argument('--tensorboard',
                        type=str2bool,
                        nargs='?',
                        help='Log progress to TensorBoard')

    parser.add_argument(
        '--save_models',
        default=True,
        type=str2bool,
        nargs='?',
        help='if True, models will be saved to the local folder')

    parser.add_argument('--fineturn_model',
                        type=str2bool,
                        nargs='?',
                        help='Log progress to TensorBoard')

    # ============================PRUNING added
    parser.add_argument(
        '--pruning_config',
        default=None,
        type=str,
        help=
        'path to pruning configuration file, will overwrite all pruning parameters in arguments'
    )

    parser.add_argument('--group_wd_coeff',
                        type=float,
                        default=0.0,
                        help='group weight decay')
    parser.add_argument('--name',
                        default='test',
                        type=str,
                        help='experiment name(folder) to store logs')

    parser.add_argument(
        '--augment',
        default=False,
        type=str2bool,
        nargs='?',
        help=
        'enable or not augmentation of training dataset, only for CIFAR, def False'
    )

    parser.add_argument('--load_model',
                        default='',
                        type=str,
                        help='path to model weights')

    parser.add_argument('--pruning',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='enable or not pruning, def False')

    parser.add_argument(
        '--pruning-threshold',
        '--pt',
        default=100.0,
        type=float,
        help=
        'Max error perc on validation set while pruning (default: 100.0 means always prune)'
    )

    parser.add_argument(
        '--pruning-momentum',
        default=0.0,
        type=float,
        help=
        'Use momentum on criteria between pruning iterations, def 0.0 means no momentum'
    )

    parser.add_argument('--pruning-step',
                        default=15,
                        type=int,
                        help='How often to check loss and do pruning step')

    parser.add_argument('--prune_per_iteration',
                        default=10,
                        type=int,
                        help='How many neurons to remove at each iteration')

    parser.add_argument(
        '--fixed_layer',
        default=-1,
        type=int,
        help='Prune only a given layer with index, use -1 to prune all')

    parser.add_argument('--start_pruning_after_n_iterations',
                        default=0,
                        type=int,
                        help='from which iteration to start pruning')

    parser.add_argument('--maximum_pruning_iterations',
                        default=1e8,
                        type=int,
                        help='maximum pruning iterations')

    parser.add_argument('--starting_neuron',
                        default=0,
                        type=int,
                        help='starting position for oracle pruning')

    parser.add_argument('--prune_neurons_max',
                        default=-1,
                        type=int,
                        help='prune_neurons_max')

    parser.add_argument('--pruning-method',
                        default=0,
                        type=int,
                        help='pruning method to be used, see readme.md')

    parser.add_argument('--pruning_fixed_criteria',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='enable or not criteria reevaluation, def False')

    parser.add_argument('--fixed_network',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='fix network for oracle or criteria computation')

    parser.add_argument(
        '--zero_lr_for_epochs',
        default=-1,
        type=int,
        help='Learning rate will be set to 0 for given number of updates')

    parser.add_argument(
        '--dynamic_network',
        default=False,
        type=str2bool,
        nargs='?',
        help=
        'Creates a new network graph from pruned model, works with ResNet-101 only'
    )

    parser.add_argument('--use_test_as_train',
                        default=False,
                        type=str2bool,
                        nargs='?',
                        help='use testing dataset instead of training')

    parser.add_argument('--pruning_mask_from',
                        default='',
                        type=str,
                        help='path to mask file precomputed')

    parser.add_argument(
        '--compute_flops',
        default=True,
        type=str2bool,
        nargs='?',
        help=
        'if True, will run dummy inference of batch 1 before training to get conv sizes'
    )

    # ============================END pruning added

    best_prec1 = 0
    global global_iteration
    global group_wd_optimizer
    global_iteration = 0

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    args.distributed = args.world_size > 1
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=0)

    device = torch.device("cuda" if use_cuda else "cpu")

    # dataset loading section
    if args.dataset == "MNIST":
        kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
        train_loader = torch.utils.data.DataLoader(datasets.MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, ))
            ])),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('../data',
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "CIFAR10":
        # Data loading code
        normalize = transforms.Normalize(
            mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
            std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if args.augment:
            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            transform_train = transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])

        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        kwargs = {'num_workers': 8, 'pin_memory': True}
        train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(
            '../data', train=True, download=True, transform=transform_train),
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   drop_last=True,
                                                   **kwargs)

        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('../data', train=False, transform=transform_test),
            batch_size=args.test_batch_size,
            shuffle=True,
            **kwargs)

    elif args.dataset == "Imagenet":
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset)
        else:
            train_sampler = None

        kwargs = {'num_workers': 16}

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            pin_memory=True,
            **kwargs)

        if args.use_test_as_train:
            train_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                **kwargs)

        test_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                  batch_size=args.batch_size,
                                                  shuffle=False,
                                                  pin_memory=True,
                                                  **kwargs)
    #wm
    elif args.dataset == "mult_5T":
        args.data_root = ['/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/CX_20200709',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TK_20200709',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/ZR_20200709',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TX_20200616',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/WM_20200709']

        args.data_root_val = ['/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/CX_20200709',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TK_20200709',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/ZR_20200709',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TX_20200616',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/WM_20200709']

        args.train_data_list = ['/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/CX_20200709/txt/cx_train.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TK_20200709/txt/tk_train.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/ZR_20200709/txt/zr_train.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TX_20200616/txt/tx_train.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/WM_20200709/txt/wm_train.txt']

        args.val_data_list = ['/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/CX_20200709/txt/cx_val.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TK_20200709/txt/tk_val.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/ZR_20200709/txt/zr_val.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/TX_20200616/txt/tx_val.txt',\
        '/workspace/mnt/storage/yangdecheng/yangdecheng/data/TR-NMA-07/WM_20200709/txt/wm_val.txt']

        num_tasks = len(args.data_root)
        args.ngpu = 8
        args.workers = 8
        args.train_batch_size = [40, 40, 40, 40, 40]  #36
        args.val_batch_size = [100, 100, 100, 100, 100]
        args.loss_weight = [1.0, 1.0, 1.0, 1.0, 1.0]
        args.val_num_classes = [[0, 1, 2, 3, 4], [0, 1, 2], [0, 1], [0, 1],
                                [0, 1, 2, 3, 4, 5, 6]]
        args.mixup_alpha = None  #None

        for i in range(num_tasks):
            args.train_batch_size[i] *= args.ngpu
            args.val_batch_size[i] *= args.ngpu

        pixel_mean = [0.406, 0.456, 0.485]
        pixel_std = [0.225, 0.224, 0.229]

        #私人定制:
        train_dataset = []
        for i in range(num_tasks):
            if i == 1:
                train_dataset.append(
                    FileListLabeledDataset(
                        args.train_data_list[i],
                        args.data_root[i],
                        Compose([
                            RandomResizedCrop(
                                112,
                                scale=(0.94, 1.),
                                ratio=(1. / 4., 4. / 1.)
                            ),  #scale=(0.7, 1.2), ratio=(1. / 1., 4. / 1.)
                            RandomHorizontalFlip(),
                            ColorJitter(brightness=[0.5, 1.5],
                                        contrast=[0.5, 1.5],
                                        saturation=[0.5, 1.5],
                                        hue=0),
                            ToTensor(),
                            Lighting(1, [0.2175, 0.0188, 0.0045],
                                     [[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]]),
                            Normalize(pixel_mean, pixel_std),
                        ])))
            else:
                train_dataset.append(
                    FileListLabeledDataset(
                        args.train_data_list[i], args.data_root[i],
                        Compose([
                            RandomResizedCrop(112,
                                              scale=(0.7, 1.2),
                                              ratio=(1. / 1., 4. / 1.)),
                            RandomHorizontalFlip(),
                            ColorJitter(brightness=[0.5, 1.5],
                                        contrast=[0.5, 1.5],
                                        saturation=[0.5, 1.5],
                                        hue=0),
                            ToTensor(),
                            Lighting(1, [0.2175, 0.0188, 0.0045],
                                     [[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]]),
                            Normalize(pixel_mean, pixel_std),
                        ])))
        #原来的
        # train_dataset  = [FileListLabeledDataset(
        # args.train_data_list[i], args.data_root[i],
        # Compose([
        #     RandomResizedCrop(112,scale=(0.7, 1.2), ratio=(1. / 1., 4. / 1.)),
        #     RandomHorizontalFlip(),
        #     ColorJitter(brightness=[0.5,1.5], contrast=[0.5,1.5], saturation=[0.5,1.5], hue= 0),
        #     ToTensor(),
        #     Lighting(1, [0.2175, 0.0188, 0.0045], [[-0.5675,  0.7192,  0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948,  0.4203]]),
        #     Normalize(pixel_mean, pixel_std),]),
        # memcached=False,
        # memcached_client="") for i in range(num_tasks)]

        args.num_classes = [td.num_class for td in train_dataset]
        train_longest_size = max([
            int(np.ceil(len(td) / float(bs)))
            for td, bs in zip(train_dataset, args.train_batch_size)
        ])
        train_sampler = [
            GivenSizeSampler(td,
                             total_size=train_longest_size * bs,
                             rand_seed=0)
            for td, bs in zip(train_dataset, args.train_batch_size)
        ]
        train_loader = [
            DataLoader(train_dataset[k],
                       batch_size=args.train_batch_size[k],
                       shuffle=False,
                       num_workers=args.workers,
                       pin_memory=False,
                       sampler=train_sampler[k]) for k in range(num_tasks)
        ]

        val_dataset = [
            FileListLabeledDataset(
                args.val_data_list[i],
                args.data_root_val[i],
                Compose([
                    Resize((112, 112)),
                    # CenterCrop(112),
                    ToTensor(),
                    Normalize(pixel_mean, pixel_std),
                ]),
                memcached=False,
                memcached_client="") for i in range(num_tasks)
        ]
        val_longest_size = max([
            int(np.ceil(len(vd) / float(bs)))
            for vd, bs in zip(val_dataset, args.val_batch_size)
        ])
        test_loader = [
            DataLoader(val_dataset[k],
                       batch_size=args.val_batch_size[k],
                       shuffle=False,
                       num_workers=args.workers,
                       pin_memory=False) for k in range(num_tasks)
        ]

    if args.model == "lenet3":
        model = LeNet(dataset=args.dataset)
    elif args.model == "vgg":
        model = vgg11_bn(pretrained=True)
    elif args.model == "resnet18":
        model = PreActResNet18()
    elif (args.model == "resnet50") or (args.model == "resnet50_noskip"):
        if args.dataset == "CIFAR10":
            model = PreActResNet50(dataset=args.dataset)
        else:
            from models.resnet import resnet50
            skip_gate = True
            if "noskip" in args.model:
                skip_gate = False

            if args.pruning_method not in [22, 40]:
                skip_gate = False
            model = resnet50(skip_gate=skip_gate)
    elif args.model == "resnet34":
        if not (args.dataset == "CIFAR10"):
            from models.resnet import resnet34
            model = resnet34()
    elif args.model == "multprun_gate5_gpu_0316_1":
        from models.multitask import MultiTaskWithLoss
        model = MultiTaskWithLoss(backbone=args.model,
                                  num_classes=args.num_classes,
                                  feature_dim=2560,
                                  spatial_size=112,
                                  arc_fc=False,
                                  feat_bn=False)
        print(model)
    elif args.model == "mult_prun8_gpu":
        from models.multitask import MultiTaskWithLoss
        model = MultiTaskWithLoss(backbone=args.model,
                                  num_classes=args.num_classes,
                                  feature_dim=18,
                                  spatial_size=112,
                                  arc_fc=False,
                                  feat_bn=False)
        print(model)
    elif args.model == "multnas5_gpu":  #作为修改项
        from models.multitask import MultiTaskWithLoss
        model = MultiTaskWithLoss(backbone=args.model,
                                  num_classes=args.num_classes,
                                  feature_dim=512,
                                  spatial_size=112,
                                  arc_fc=False,
                                  feat_bn=False)
        print(model)
    elif "resnet101" in args.model:
        if not (args.dataset == "CIFAR10"):
            from models.resnet import resnet101
            if args.dataset == "Imagenet":
                classes = 1000

            if "noskip" in args.model:
                model = resnet101(num_classes=classes, skip_gate=False)
            else:
                model = resnet101(num_classes=classes)

    elif args.model == "resnet20":
        if args.dataset == "CIFAR10":
            NotImplementedError(
                "resnet20 is not implemented in the current project")
            # from models.resnet_cifar import resnet20
            # model = resnet20()
    elif args.model == "resnet152":
        model = PreActResNet152()
    elif args.model == "densenet201_imagenet":
        from models.densenet_imagenet import DenseNet201
        model = DenseNet201(gate_types=['output_bn'], pretrained=True)
    elif args.model == "densenet121_imagenet":
        from models.densenet_imagenet import DenseNet121
        model = DenseNet121(gate_types=['output_bn'], pretrained=True)
    else:
        print(args.model, "model is not supported")

    ####end dataset preparation

    if args.dynamic_network:
        # attempts to load pruned model and modify it be removing pruned channels
        # works for resnet101 only
        if (len(args.load_model) > 0) and (args.dynamic_network):
            if os.path.isfile(args.load_model):
                load_model_pytorch(model, args.load_model, args.model)

            else:
                print("=> no checkpoint found at '{}'".format(args.load_model))
                exit()

        dynamic_network_change_local(model)

        # save the model
        log_save_folder = "%s" % args.name
        if not os.path.exists(log_save_folder):
            os.makedirs(log_save_folder)

        if not os.path.exists("%s/models" % (log_save_folder)):
            os.makedirs("%s/models" % (log_save_folder))

        model_save_path = "%s/models/pruned.weights" % (log_save_folder)
        model_state_dict = model.state_dict()
        if args.save_models:
            save_checkpoint({'state_dict': model_state_dict},
                            False,
                            filename=model_save_path)

    print("model is defined")

    # aux function to get size of feature maps
    # First it adds hooks for each conv layer
    # Then runs inference with 1 image
    output_sizes = get_conv_sizes(args, model)

    if use_cuda and not args.mgpu:
        model = model.to(device)
    elif args.distributed:
        model.cuda()
        print(
            "\n\n WARNING: distributed pruning was not verified and might not work correctly"
        )
        model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.mgpu:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.to(device)

    print(
        "model is set to device: use_cuda {}, args.mgpu {}, agrs.distributed {}"
        .format(use_cuda, args.mgpu, args.distributed))

    weight_decay = args.wd
    if args.fixed_network:
        weight_decay = 0.0

    # remove updates from gate layers, because we want them to be 0 or 1 constantly
    if 1:
        parameters_for_update = []
        parameters_for_update_named = []
        for name, m in model.named_parameters():
            if "gate" not in name:
                parameters_for_update.append(m)
                parameters_for_update_named.append((name, m))
            else:
                print("skipping parameter", name, "shape:", m.shape)

    total_size_params = sum(
        [np.prod(par.shape) for par in parameters_for_update])
    print("Total number of parameters, w/o usage of bn consts: ",
          total_size_params)

    optimizer = optim.SGD(parameters_for_update,
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=weight_decay)

    if 1:
        # helping optimizer to implement group lasso (with very small weight that doesn't affect training)
        # will be used to calculate number of remaining flops and parameters in the network
        group_wd_optimizer = group_lasso_decay(
            parameters_for_update,
            group_lasso_weight=args.group_wd_coeff,
            named_parameters=parameters_for_update_named,
            output_sizes=output_sizes)

    cudnn.benchmark = True

    # define objective
    criterion = nn.CrossEntropyLoss()

    ###=======================added for pruning
    # logging part
    log_save_folder = "%s" % args.name
    if not os.path.exists(log_save_folder):
        os.makedirs(log_save_folder)

    if not os.path.exists("%s/models" % (log_save_folder)):
        os.makedirs("%s/models" % (log_save_folder))

    train_writer = None
    if args.tensorboard:
        try:
            # tensorboardX v1.6
            train_writer = SummaryWriter(log_dir="%s" % (log_save_folder))
        except:
            # tensorboardX v1.7
            train_writer = SummaryWriter(logdir="%s" % (log_save_folder))

    time_point = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
    textfile = "%s/log_%s.txt" % (log_save_folder, time_point)
    stdout = Logger(textfile)
    sys.stdout = stdout
    print(" ".join(sys.argv))

    # initializing parameters for pruning
    # we can add weights of different layers or we can add gates (multiplies output with 1, useful only for gradient computation)
    pruning_engine = None
    if args.pruning:
        pruning_settings = dict()
        if not (args.pruning_config is None):
            pruning_settings_reader = PruningConfigReader()
            pruning_settings_reader.read_config(args.pruning_config)
            pruning_settings = pruning_settings_reader.get_parameters()

        # overwrite parameters from config file with those from command line
        # needs manual entry here
        # user_specified = [key for key in vars(default_args).keys() if not (vars(default_args)[key]==vars(args)[key])]
        # argv_of_interest = ['pruning_threshold', 'pruning-momentum', 'pruning_step', 'prune_per_iteration',
        #                     'fixed_layer', 'start_pruning_after_n_iterations', 'maximum_pruning_iterations',
        #                     'starting_neuron', 'prune_neurons_max', 'pruning_method']

        has_attribute = lambda x: any([x in a for a in sys.argv])

        if has_attribute('pruning-momentum'):
            pruning_settings['pruning_momentum'] = vars(
                args)['pruning_momentum']
        if has_attribute('pruning-method'):
            pruning_settings['method'] = vars(args)['pruning_method']

        pruning_parameters_list = prepare_pruning_list(
            pruning_settings,
            model,
            model_name=args.model,
            pruning_mask_from=args.pruning_mask_from,
            name=args.name)
        print("Total pruning layers:", len(pruning_parameters_list))

        folder_to_write = "%s" % log_save_folder + "/"
        log_folder = folder_to_write

        pruning_engine = pytorch_pruning(pruning_parameters_list,
                                         pruning_settings=pruning_settings,
                                         log_folder=log_folder)

        pruning_engine.connect_tensorboard(train_writer)
        pruning_engine.dataset = args.dataset
        pruning_engine.model = args.model
        pruning_engine.pruning_mask_from = args.pruning_mask_from
        pruning_engine.load_mask()
        gates_to_params = connect_gates_with_parameters_for_flops(
            args.model, parameters_for_update_named)
        pruning_engine.gates_to_params = gates_to_params

    ###=======================end for pruning
    # loading model file
    if (len(args.load_model) > 0) and (not args.dynamic_network):
        if os.path.isfile(args.load_model):
            if args.fineturn_model:
                checkpoint = torch.load(args.load_model)
                state_dict = checkpoint['state_dict']
                model = load_module_state_dict_checkpoint(model, state_dict)
            else:
                load_model_pytorch(model, args.load_model, args.model)
        else:
            print("=> no checkpoint found at '{}'".format(args.load_model))
            exit()

    if args.tensorboard and 0:
        if args.dataset == "CIFAR10":
            dummy_input = torch.rand(1, 3, 32, 32).to(device)
        elif args.dataset == "Imagenet":
            dummy_input = torch.rand(1, 3, 224, 224).to(device)

        train_writer.add_graph(model, dummy_input)

    for epoch in range(1, args.epochs + 1):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(args, optimizer, epoch, args.zero_lr_for_epochs,
                             train_writer)

        if not args.run_test and not args.get_inference_time:
            train(args,
                  model,
                  device,
                  train_loader,
                  optimizer,
                  epoch,
                  criterion,
                  train_writer=train_writer,
                  pruning_engine=pruning_engine)

        if args.pruning:
            # skip validation error calculation and model saving
            if pruning_engine.method == 50: continue

        # evaluate on validation set
        prec1 = validate(args,
                         test_loader,
                         model,
                         device,
                         criterion,
                         epoch,
                         train_writer=train_writer)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        model_save_path = "%s/models/checkpoint.weights" % (log_save_folder)
        paths = "%s/models" % (log_save_folder)
        model_state_dict = model.state_dict()
        if args.save_models:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model_state_dict,
                    'best_prec1': best_prec1,
                },
                is_best,
                filename=model_save_path)
            states = {
                'epoch': epoch + 1,
                'state_dict': model_state_dict,
            }
            torch.save(states, '{}/{}.pth.tar'.format(paths, epoch + 1))
def main(args):
    """Build phase-2 dataloaders and either fine-tune the model or run testing.

    Relies on module-level names: ``root_dir``, ``model_dir``, ``save_dir``,
    ``dataset_unequal``, ``dataset_R3D``, ``fine_tune`` and ``model_test``.

    Args:
        args: parsed CLI namespace providing ``is_train``, ``w``, ``h``,
            ``n_frames_per_clip``, ``num_workers``, ``img_size`` and
            ``data_type``.
    """
    # Fix all RNG seeds so shuffling is identical between runs.
    seed = 1
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.

    # Identity normalization (zero mean, unit std): frames are only scaled
    # by ToTensor; change here if dataset statistics become available.
    norm_method = Normalize([0, 0, 0], [1, 1, 1])
    trans_train = Compose([
        RandomHorizontalFlip(),
        ToTensor(1),
        norm_method
    ])
    trans_test = Compose([
        ToTensor(1),
        norm_method
    ])

    if args.is_train:
        # Dataloaders for phase 2 training.
        mask_trans = transforms.Compose([
            transforms.ToTensor()
        ])
        print('Loading phase2 training data.....')
        dataset_train = dataset_unequal.dataset_all(
            root_dir,
            'train',
            n_frames_per_clip=args.n_frames_per_clip,
            UnequalSequence=True,
            img_size=(args.w, args.h),
            stride=2,
            reverse=False,
            transform=trans_train,
            mask_trans=mask_trans)
        dataloader_train = DataLoader(dataset_train,
                                      batch_size=128,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        print('\n')
        print('Loading phase2 validating data.....')
        dataset_val = dataset_unequal.dataset_all(
            root_dir,
            'val',
            n_frames_per_clip=args.n_frames_per_clip,
            img_size=(args.w, args.h),
            stride=2,
            UnequalSequence=True,
            reverse=False,
            transform=trans_test,
            mask_trans=mask_trans)
        # Validate on a fixed-size random subset (with replacement) to keep
        # validation passes cheap.
        sampler = RandomSampler(dataset_val,
                                replacement=True,
                                num_samples=1024)
        dataloader_val = DataLoader(dataset_val,
                                    batch_size=64,
                                    sampler=sampler,
                                    num_workers=args.num_workers,
                                    pin_memory=True)
    else:
        dataset_test = dataset_R3D(root_dir,
                                   'test',
                                   args.data_type,
                                   args.n_frames_per_clip,
                                   img_size=args.img_size,
                                   stride=args.n_frames_per_clip,
                                   overlap=False,
                                   transform=trans_test)
        # NOTE(review): shuffle=True on a test loader is unusual -- confirm
        # evaluation does not depend on clip order.
        dataloader_test = DataLoader(dataset_test,
                                     batch_size=128,
                                     shuffle=True,
                                     num_workers=8,
                                     pin_memory=True)
    if args.is_train:
        # Removed a leftover `pdb.set_trace()` debugging breakpoint that
        # halted every training run right after fine-tuning finished.
        fine_tune(model_dir,
                  save_dir,
                  'resnet-152-kinetics.pth',
                  dataloader_train,
                  dataloader_val,
                  ContinuousTrain=False)
    else:
        model_test('output', 'checkpoint-3.pth', dataloader_test)
# Esempio n. 17
# 0
def main(args):
    """Train or evaluate a semantic segmentation model on Pascal VOC.

    The run mode is selected by flags on ``args``: plain training (default),
    evaluation only (``test_only``), LR search (``lr_finder``), one-batch
    overfitting check (``check_setup``) or visualization
    (``show_samples`` / ``show_preds``).

    Raises:
        ValueError: if ``args.loss`` or ``args.opt`` is not a supported value.
        RuntimeError: if a visualization flag is used without a training set.
    """
    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading: ImageNet normalization statistics.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    # Training-time random-resize range around the base size.
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        # Fail fast: with test_only the training loader was never built and
        # next(iter(None)) would raise an opaque TypeError.
        if train_loader is None:
            raise RuntimeError("show_samples requires the training set; "
                               "do not combine it with test_only")
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size)),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = segmentation.__dict__[args.model](
        args.pretrained,
        not (args.pretrained),
        num_classes=len(VOC_CLASSES),
    )

    # Loss setup: optionally re-weight the background class (index 0).
    loss_weight = None
    if isinstance(args.bg_factor, float):
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'label_smoothing':
        criterion = holocron.nn.LabelSmoothingCrossEntropy(weight=loss_weight,
                                                           ignore_index=255)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight,
                                                  ignore_index=255)
    else:
        # Previously an unknown value crashed later with a NameError on
        # `criterion`; raise a clear error at the decision point instead.
        raise ValueError(f"unsupported loss: {args.loss!r}")

    # Optimizer setup: only parameters that require gradients.
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params,
                                     args.lr,
                                     betas=(0.95, 0.99),
                                     eps=1e-6,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)
    else:
        # Same fail-fast treatment as the loss selection above.
        raise ValueError(f"unsupported optimizer: {args.opt!r}")

    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES))
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        # Same guard as show_samples: predictions are drawn on a train batch.
        if train_loader is None:
            raise RuntimeError("show_preds requires the training set; "
                               "do not combine it with test_only")
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
# Esempio n. 18
# 0
def init(batch_size, state, split, input_sizes, sets_id, std, mean, keep_scale, reverse_channels, data_set,
         valtiny, no_aug):
    """Build the dataloaders for a semi/fully-supervised segmentation run.

    Args:
        batch_size: per-loader batch size (halved for the two labeled loaders
            in the semi-supervised case).
        state: run mode -- 1: semi-supervised training, 2: fully-supervised
            training, 3: testing only.
        split: data split identifier used to build image-set names.
        input_sizes: 3-tuple of sizes (train min, train max, test).
        sets_id: identifier appended to the labeled/unlabeled set names.
        std, mean: normalization statistics.
        keep_scale, reverse_channels: forwarded to the custom ToTensor.
        data_set: 'voc' or 'city'.
        valtiny: if True, validate on the reduced 'valtiny' image set.
        no_aug: if True, disable augmentation on the pseudo-labeled loader.

    Returns:
        state == 3: val_loader
        state == 2: (labeled_loader, val_loader)
        state == 1: (labeled_loader, pseudo_labeled_loader, unlabeled_loader,
                     val_loader, reference_loader)

    Raises:
        ValueError: unknown ``data_set`` or unsupported ``state``.
    """
    # Transformations (compatible with unlabeled data/pseudo labeled data)
    # ! Can't use torchvision.Transforms.Compose
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             ZeroPad(size=input_sizes[2]),
             Normalize(mean=mean, std=std)])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
    else:
        # Previously this set base = '' and crashed later with a NameError on
        # the undefined transforms; fail fast with a clear message instead.
        raise ValueError(f"unsupported data_set: {data_set!r} (expected 'voc' or 'city')")

    # Not the actual test set (i.e.validation set)
    test_set = StandardSegmentationDataset(root=base, image_set='valtiny' if valtiny else 'val',
                                           transforms=transform_test, label_state=0, data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, num_workers=workers, shuffle=False)

    # Testing
    if state == 3:
        return val_loader
    else:
        # Fully-supervised training
        if state == 2:
            labeled_set = StandardSegmentationDataset(root=base, image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0, data_set=data_set)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set, batch_size=batch_size,
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, val_loader

        # Semi-supervised training
        elif state == 1:
            pseudo_labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                             image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                             transforms=transform_train_pseudo, label_state=1)
            reference_set = SegmentationLabelsDataset(root=base, image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                      data_set=data_set)
            reference_loader = torch.utils.data.DataLoader(dataset=reference_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)
            unlabeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                        image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                        transforms=transform_pseudo, label_state=2)
            labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                      image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0)

            unlabeled_loader = torch.utils.data.DataLoader(dataset=unlabeled_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)

            # Labeled and pseudo-labeled samples share the batch budget 50/50.
            pseudo_labeled_loader = torch.utils.data.DataLoader(dataset=pseudo_labeled_set,
                                                                batch_size=int(batch_size / 2),
                                                                num_workers=workers, shuffle=True)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set,
                                                         batch_size=int(batch_size / 2),
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, reference_loader

        else:
            # Support unsupervised learning here if that's what you want
            raise ValueError
# Esempio n. 19
# 0
def main():
    """Configure and train an RDUNet image-denoising model.

    Reads hyper-parameters from ``config.yaml``, builds the model (wrapped in
    ``nn.DataParallel`` when multiple GPUs are available and enabled), prepares
    noisy-image training/validation datasets and loaders, then delegates the
    training loop to ``fit_model``.
    """
    with open('config.yaml', 'r') as stream:  # Load YAML configuration file.
        config = yaml.safe_load(stream)

    # Three top-level config sections drive everything below.
    model_params = config['model']
    train_params = config['train']
    val_params = config['val']

    # Defining model:
    set_seed(0)  # Fixed seed so weight initialization is reproducible.
    model = RDUNet(**model_params)

    print('Model summary:')
    # Complexity is measured on a single training-size patch (C, H, W).
    test_shape = (model_params['channels'], train_params['patch size'],
                  train_params['patch size'])
    # NOTE(review): get_model_complexity_info matches the ptflops API — confirm
    # which library provides it; no_grad avoids building autograd graphs here.
    with torch.no_grad():
        macs, params = get_model_complexity_info(model,
                                                 test_shape,
                                                 as_strings=True,
                                                 print_per_layer_stat=False,
                                                 verbose=False)
        print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
        print('{:<30}  {:<8}'.format('Number of parameters: ', params))

    # Define the model name (checkpoint prefix) and use multi-GPU if allowed.
    model_name = 'model_color' if model_params[
        'channels'] == 3 else 'model_gray'
    device = torch.device(train_params['device'])
    print("Using device: {}".format(device))
    if torch.cuda.device_count(
    ) > 1 and 'cuda' in device.type and train_params['multi gpu']:
        model = nn.DataParallel(model)
        print('Using multiple GPUs')

    model = model.to(device)
    # Per-parameter optimizer groups: weight decay is applied only to
    # convolution weights; biases and all other parameters get zero decay.
    param_group = []
    for name, param in model.named_parameters():
        if 'conv' in name and 'weight' in name:
            p = {'params': param, 'weight_decay': train_params['weight decay']}
        else:
            p = {'params': param, 'weight_decay': 0.}
        param_group.append(p)

    # Load training and validation file names (one relative path per line),
    # resolved against each split's configured dataset root.
    # Modify .txt files if datasets do not fit in memory.
    with open('train_files.txt', 'r') as f_train, open('val_files.txt',
                                                       'r') as f_val:
        raw_train_files = f_train.read().splitlines()
        raw_val_files = f_val.read().splitlines()
        train_files = list(
            map(lambda file: join(train_params['dataset path'], file),
                raw_train_files))
        val_files = list(
            map(lambda file: join(val_params['dataset path'], file),
                raw_val_files))

    # Geometric augmentations for training patches only (validation gets None).
    training_transforms = transforms.Compose(
        [RandomHorizontalFlip(),
         RandomVerticalFlip(),
         RandomRot90()])

    # Predefined noise level for training; clip=True presumably keeps noisy
    # pixels within the valid intensity range — verify in the transform.
    train_noise_transform = [
        AdditiveWhiteGaussianNoise(train_params['noise level'], clip=True)
    ]
    # One fixed-sigma noise transform per requested validation noise level.
    val_noise_transforms = [
        AdditiveWhiteGaussianNoise(s, fix_sigma=True, clip=True)
        for s in val_params['noise levels']
    ]

    print('\nLoading training dataset:')
    training_dataset = NoisyImagesDataset(train_files,
                                          model_params['channels'],
                                          train_params['patch size'],
                                          training_transforms,
                                          train_noise_transform)

    print('\nLoading validation dataset:')
    validation_dataset = NoisyImagesDataset(val_files,
                                            model_params['channels'],
                                            val_params['patch size'], None,
                                            val_noise_transforms)
    # Training in sub-epochs: the dataset is divided into 'dataset splits'
    # parts and the epoch count scaled up accordingly, so each scheduler/epoch
    # tick covers n_samples patches (presumably drawn by DataSampler — verify).
    print('Training patches:', len(training_dataset))
    print('Validation patches:', len(validation_dataset))
    n_samples = len(training_dataset) // train_params['dataset splits']
    n_epochs = train_params['epochs'] * train_params['dataset splits']
    sampler = DataSampler(training_dataset, num_samples=n_samples)

    data_loaders = {
        'train':
        DataLoader(training_dataset,
                   train_params['batch size'],
                   num_workers=train_params['workers'],
                   sampler=sampler),
        'val':
        DataLoader(validation_dataset,
                   val_params['batch size'],
                   num_workers=val_params['workers']),
    }

    # Optimization:
    learning_rate = train_params['learning rate']
    # Scheduler step is configured in full epochs; convert to sub-epochs.
    step_size = train_params['scheduler step'] * train_params['dataset splits']

    criterion = nn.L1Loss()  # L1 loss is common for denoising objectives.
    optimizer = optim.AdamW(param_group, lr=learning_rate)
    lr_scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=step_size, gamma=train_params['scheduler gamma'])

    # Train the model
    fit_model(model, data_loaders, model_params['channels'], criterion,
              optimizer, lr_scheduler, device, n_epochs,
              val_params['frequency'], train_params['checkpoint path'],
              model_name)