Example #1
    def __init__(self, content=None):

        get = content.get_resource

        self.__model = onnxruntime.InferenceSession(
            get(content.model_path))

        self.line_transform = Compose([
            Resize((512, 512)),
            ToTensor(),
            Normalize([0.5], [0.5]),
            Lambda(lambda img: np.expand_dims(img, 0))
        ])
        self.hint_transform = Compose([
            # input must be RGBA!
            Resize((128, 128), Image.NEAREST),
            Lambda(lambda img: img.convert(mode='RGB')),
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            Lambda(lambda img: np.expand_dims(img, 0))
        ])
        self.line_draft_transform = Compose([
            Resize((128, 128)),
            ToTensor(),
            Normalize([0.5], [0.5]),
            Lambda(lambda img: np.expand_dims(img, 0))
        ])
        self.alpha_transform = Compose([
            Lambda(lambda img: self.get_alpha(img)),
        ])
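
A minimal usage sketch for the pipeline above (the enclosing class and method name are hypothetical and not part of the original snippet; the ONNX input name is looked up rather than assumed):

    def colorize_line(self, pil_image):
        # hypothetical helper: push line art through line_transform and the ONNX session
        batch = self.line_transform(pil_image.convert('L'))  # numpy, shape (1, 1, 512, 512)
        input_name = self.__model.get_inputs()[0].name
        return self.__model.run(None, {input_name: batch})[0]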
Example #2
def valid_dataset(root_dir,
                  normalization=None,
                  grayscale=False,
                  square=False,
                  csv_file='B/validation_data.csv',
                  scale=1.0):
    """This function loads the training dataset with the desired transformations."""

    scaled_width = int(round(scale * constants.default_width))

    transformations = []
    if scale != 1.0:
        transformations.append(Resize(scaled_width))

    if square:
        transformations.append(Square())

    if grayscale:
        transformations.append(BlackAndWhite())

    transformations.append(ToTensor())

    if normalization is not None:
        transformations.append(normalization)

    transform = transforms.Compose(transformations)

    valid_dataset = PostureLandmarksDataset(csv_file=csv_file,
                                            root_dir=root_dir,
                                            transform=transform)
    return valid_dataset
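
A usage sketch (root directory and loader settings are illustrative):

from torch.utils.data import DataLoader

# hypothetical usage: build the validation set and wrap it in a loader
dataset = valid_dataset(root_dir='data/', grayscale=True, scale=0.5)
loader = DataLoader(dataset, batch_size=8, shuffle=False)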
Example #3
def preprocess(img_path, size):
    # image transforms
    transform = Compose([
        Resize(
            size,
            size,
            resize_target=None,
            keep_aspect_ratio=False,
            ensure_multiple_of=32,
            resize_method="upper_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet()
    ])

    # read the image
    img = cv2.imread(img_path)

    # convert BGR to RGB and normalize to [0, 1]
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

    # apply the transforms
    img = transform({"image": img})["image"]

    # add a batch dimension
    img = img[np.newaxis, ...]

    return img
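
A sketch of feeding the preprocessed array to ONNX Runtime (the model file name is hypothetical; the input name is queried from the session):

import numpy as np
import onnxruntime

# hypothetical usage of preprocess() with an ONNX depth model
session = onnxruntime.InferenceSession("midas.onnx")
img = preprocess("input.jpg", 384)  # (1, 3, H, W) array
depth = session.run(None, {session.get_inputs()[0].name: img.astype(np.float32)})[0]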
Example #4
def get_transforms():
    data_transform = torchvision.transforms.Compose([
        ToTensor(),
        Normalize(mean=constants.DATA_MEAN, std=constants.DATA_STD),
        Resize(constants.TRANSFORMED_IMAGE_SIZE)
    ])
    return data_transform
Example #5
def load_data(datadir, img_size=416, crop_pct=0.875):
    # Data loading code
    print("Loading data")
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale_size = int(math.floor(img_size / crop_pct))

    print("Loading training data")
    st = time.time()
    dataset = VOCDetection(datadir, image_set='train', download=True,
                           transforms=Compose([VOCTargetTransform(classes),
                                               RandomResizedCrop((img_size, img_size), scale=(0.3, 1.0)),
                                               RandomHorizontalFlip(),
                                               convert_to_relative,
                                               ImageTransform(transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                                                                     saturation=0.1, hue=0.02)),
                                               ImageTransform(transforms.ToTensor()), ImageTransform(normalize)]))

    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    dataset_test = VOCDetection(datadir, image_set='val', download=True,
                                transforms=Compose([VOCTargetTransform(classes),
                                                    Resize(scale_size), CenterCrop(img_size),
                                                    convert_to_relative,
                                                    ImageTransform(transforms.ToTensor()), ImageTransform(normalize)]))

    print("Took", time.time() - st)
    print("Creating data loaders")
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset, dataset_test, train_sampler, test_sampler
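
A usage sketch (directory and batch size are illustrative; detection targets usually need a custom collate_fn, which is omitted here):

# hypothetical usage of load_data()
train_set, test_set, train_sampler, test_sampler = load_data('VOC/', img_size=416)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=16,
                                           sampler=train_sampler)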
Example #6
def init_seg(input_sizes,
             std,
             mean,
             dataset,
             test_base=None,
             test_label_id_map=None,
             city_aug=0):

    if dataset == 'voc':
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes),
            Normalize(mean=mean, std=std)
        ])
    elif dataset == 'city' or dataset == 'gtav' or dataset == 'synthia':  # All the same size
        if city_aug == 2:  # ERFNet and ENet
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes, size_label=input_sizes),
                LabelMap(test_label_id_map)
            ])
        elif city_aug == 1:  # City big
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes, size_label=input_sizes),
                Normalize(mean=mean, std=std),
                LabelMap(test_label_id_map)
            ])
        else:  # guard: without this, transform_test would be unbound below
            raise ValueError('city_aug must be 1 or 2 for this dataset')
    else:
        raise ValueError(f"unsupported dataset: {dataset}")

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(
        root=test_base,
        image_set='val',
        transforms=transform_test,
        data_set='city'
        if dataset == 'gtav' or dataset == 'synthia' else dataset)

    val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                             batch_size=1,
                                             num_workers=0,
                                             shuffle=False)

    # Testing
    return val_loader
Example #7
def init(batch_size, state, input_sizes, dataset, mean, std, base, workers=10):
    # Return data_loaders
    # depending on whether the state is
    # 0: training
    # 1: fast validation by mean IoU (validation set)
    # 2: just testing (test set)
    # 3: just testing (validation set)

    # Transformations
    # NB: can't use torchvision.transforms.Compose (these are joint image/label transforms)
    transforms_test = Compose(
        [Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
         ToTensor(),
         Normalize(mean=mean, std=std)])
    transforms_train = Compose(
        [Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
         RandomRotation(degrees=3),
         ToTensor(),
         Normalize(mean=mean, std=std)])

    if state == 0:
        data_set = StandardLaneDetectionDataset(root=base, image_set='train', transforms=transforms_train,
                                                data_set=dataset)
        data_loader = torch.utils.data.DataLoader(dataset=data_set, batch_size=batch_size,
                                                  num_workers=workers, shuffle=True)
        validation_set = StandardLaneDetectionDataset(root=base, image_set='val',
                                                      transforms=transforms_test, data_set=dataset)
        validation_loader = torch.utils.data.DataLoader(dataset=validation_set, batch_size=batch_size * 4,
                                                        num_workers=workers, shuffle=False)
        return data_loader, validation_loader

    elif state in (1, 2, 3):
        image_sets = ['valfast', 'test', 'val']
        data_set = StandardLaneDetectionDataset(root=base, image_set=image_sets[state - 1],
                                                transforms=transforms_test, data_set=dataset)
        data_loader = torch.utils.data.DataLoader(dataset=data_set, batch_size=batch_size,
                                                  num_workers=workers, shuffle=False)
        return data_loader
    else:
        raise ValueError
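
A usage sketch for init() (all values are illustrative; input_sizes is indexed with [0], so a sequence of sizes is expected):

# hypothetical usage: state 0 returns both training and validation loaders
train_loader, val_loader = init(batch_size=8, state=0,
                                input_sizes=[(288, 800)], dataset='culane',
                                mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225],
                                base='data/culane')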
Example #8
def init_lane(input_sizes, dataset, mean, std, base, workers=0):
    transforms_test = Compose([
        Resize(size_image=input_sizes, size_label=input_sizes),
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    validation_set = StandardLaneDetectionDataset(root=base,
                                                  image_set='val',
                                                  transforms=transforms_test,
                                                  data_set=dataset)
    validation_loader = torch.utils.data.DataLoader(dataset=validation_set,
                                                    batch_size=1,
                                                    num_workers=workers,
                                                    shuffle=False)
    return validation_loader
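
A usage sketch (dataset name and paths are illustrative):

# hypothetical usage: single-image validation loader for lane detection
val_loader = init_lane(input_sizes=(288, 800), dataset='culane',
                       mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                       base='data/culane')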
Example #9
def get_transform_fixsize(train=True, img_size=416,
                          image_mean=None, image_std=None, advanced=False):
    if image_mean is None:
        image_mean = [0.485, 0.456, 0.406]
    if image_std is None:
        image_std = [0.229, 0.224, 0.225]
    if train:
        transforms = Compose(
            [
                Augment(advanced),
                Pad(),
                ToTensor(),
                Resize(img_size),
                RandomHorizontalFlip(0.5),
                Normalize(image_mean, image_std)
            ])
    else:
        transforms = Compose(
            [
                Pad(),
                ToTensor(),
                Resize(img_size),
                # RandomHorizontalFlip(0.5),
                Normalize(image_mean, image_std)
            ])
    return transforms
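
A usage sketch contrasting the two branches:

# hypothetical usage: augmented pipeline for training, deterministic one for evaluation
train_tf = get_transform_fixsize(train=True, img_size=416)
eval_tf = get_transform_fixsize(train=False, img_size=416)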
Example #10
def load_data(datadir, img_size=416, crop_pct=0.875):
    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    scale_size = int(math.floor(img_size / crop_pct))

    print("Loading training data")
    st = time.time()
    train_set = VOCDetection(datadir,
                             image_set='train',
                             download=True,
                             transforms=Compose([
                                 VOCTargetTransform(VOC_CLASSES),
                                 RandomResizedCrop((img_size, img_size),
                                                   scale=(0.3, 1.0)),
                                 RandomHorizontalFlip(), convert_to_relative,
                                 ImageTransform(
                                     transforms.ColorJitter(brightness=0.3,
                                                            contrast=0.3,
                                                            saturation=0.1,
                                                            hue=0.02)),
                                 ImageTransform(transforms.ToTensor()),
                                 ImageTransform(normalize)
                             ]))

    print("Took", time.time() - st)

    print("Loading validation data")
    st = time.time()
    val_set = VOCDetection(datadir,
                           image_set='val',
                           download=True,
                           transforms=Compose([
                               VOCTargetTransform(VOC_CLASSES),
                               Resize(scale_size),
                               CenterCrop(img_size), convert_to_relative,
                               ImageTransform(transforms.ToTensor()),
                               ImageTransform(normalize)
                           ]))

    print("Took", time.time() - st)

    return train_set, val_set
Example #11
def get_testing_loader(img_root, label_root, file_list, batch_size,
                       img_size, num_class):
    transformed_dataset = VOCDataset(
        img_root,
        label_root,
        file_list,
        transform=transforms.Compose([
            Resize(img_size),
            ToTensor(),
            Normalize(imagenet_stats['mean'], imagenet_stats['std']),
            # GenOneHotLabel(num_class),
        ]))
    loader = DataLoader(
        transformed_dataset,
        batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
    )
    return loader
Example #12
def get_training_loader(img_root, label_root, file_list, batch_size,
                        img_height, img_width, num_class):
    transformed_dataset = VOCTestDataset(
        img_root,
        label_root,
        file_list,
        transform=transforms.Compose([
            RandomHorizontalFlip(),
            Resize((img_height + 5, img_width + 5)),
            RandomCrop((img_height, img_width)),
            ToTensor(),
            Normalize(imagenet_stats['mean'], imagenet_stats['std']),
            # GenOneHotLabel(num_class),
        ]))
    loader = DataLoader(
        transformed_dataset,
        batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
    )
    return loader
Example #13
def pre_processing(input_path, output_path, device):
    """Preprocess images for MonoDepthNN and yield model-ready samples.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        device (torch.device or str): device to move the input tensors to
    """
    transform = Compose([
        Resize(
            384,
            384,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="upper_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])

    # model.to(device)
    # model.eval()

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    print("start processing")

    for ind, img_name in enumerate(img_names):
        # print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))
        # input
        img = midas_utils.read_image(img_name)
        img_input = transform({"image": img})["image"]
        sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
        yield (sample, img, img_name, num_images)
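
A usage sketch for the generator above (paths and device string are illustrative):

# hypothetical usage: consume the preprocessing generator
for sample, img, img_name, num_images in pre_processing('input/', 'output/', 'cpu'):
    pass  # run the depth model on `sample` here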
Example #14
def init(batch_size,
         state,
         input_sizes,
         dataset,
         mean,
         std,
         base,
         workers=10,
         method='baseline'):
    # Return data_loaders
    # depending on whether the state is
    # 0: training
    # 1: fast validation by mean IoU (validation set)
    # 2: just testing (test set)
    # 3: just testing (validation set)

    # Transformations
    # NB: can't use torchvision.transforms.Compose (these are joint image/label transforms)
    transforms_test = Compose([
        Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
        ToTensor(),
        Normalize(mean=mean, std=std)
    ])
    transforms_train = Compose([
        Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
        RandomRotation(degrees=3),
        ToTensor(),
        Normalize(mean=mean,
                  std=std,
                  normalize_target=(method == 'lstr'))
    ])

    # Batch builder
    if method == 'lstr':
        collate_fn = dict_collate_fn
    else:
        collate_fn = None

    if state == 0:
        if method == 'lstr':
            if dataset == 'tusimple':
                data_set = TuSimple(root=base,
                                    image_set='train',
                                    transforms=transforms_train,
                                    padding_mask=True,
                                    process_points=True)
            elif dataset == 'culane':
                data_set = CULane(root=base,
                                  image_set='train',
                                  transforms=transforms_train,
                                  padding_mask=True,
                                  process_points=True)
            else:
                raise ValueError
        else:
            data_set = StandardLaneDetectionDataset(
                root=base,
                image_set='train',
                transforms=transforms_train,
                data_set=dataset)

        data_loader = torch.utils.data.DataLoader(dataset=data_set,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn,
                                                  num_workers=workers,
                                                  shuffle=True)
        validation_set = StandardLaneDetectionDataset(
            root=base,
            image_set='val',
            transforms=transforms_test,
            data_set=dataset)
        validation_loader = torch.utils.data.DataLoader(dataset=validation_set,
                                                        batch_size=batch_size * 4,
                                                        num_workers=workers,
                                                        shuffle=False,
                                                        collate_fn=collate_fn)
        return data_loader, validation_loader

    elif state in (1, 2, 3):
        image_sets = ['valfast', 'test', 'val']
        if method == 'lstr':
            if dataset == 'tusimple':
                data_set = TuSimple(root=base,
                                    image_set=image_sets[state - 1],
                                    transforms=transforms_test,
                                    padding_mask=False,
                                    process_points=False)
            elif dataset == 'culane':
                data_set = CULane(root=base,
                                  image_set=image_sets[state - 1],
                                  transforms=transforms_test,
                                  padding_mask=False,
                                  process_points=False)
            else:
                raise ValueError
        else:
            data_set = StandardLaneDetectionDataset(
                root=base,
                image_set=image_sets[state - 1],
                transforms=transforms_test,
                data_set=dataset)
        data_loader = torch.utils.data.DataLoader(dataset=data_set,
                                                  batch_size=batch_size,
                                                  collate_fn=collate_fn,
                                                  num_workers=workers,
                                                  shuffle=False)
        return data_loader
    else:
        raise ValueError
Example #15
def run(input_path, output_path, model_path):
    """Run MonoDepthNN to compute depth maps.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
    print("initialize")

    # select device
    device = "CUDA:0"
    #device = "CPU"
    print("device: %s" % device)

    # load network
    print("loading model...")
    model = onnx.load(model_path)
    print("checking model...")
    onnx.checker.check_model(model)
    print("preparing model...")
    tf_rep = onnx_tf.backend.prepare(model, device)

    print('inputs:', tf_rep.inputs)
    print('outputs:', tf_rep.outputs)

    resize_image = Resize(
                384,
                384,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            )
    
    def compose2(f1, f2):
        return lambda x: f2(f1(x))
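
    # compose2 chains the two dict-based transforms left-to-right:
    # the sample is resized first, then PrepareForNet formats it for the model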

    transform = compose2(resize_image, PrepareForNet())

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names):

        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input

        img = utils.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        output = tf_rep.run(img_input.reshape(1, 3, 384, 384))
        prediction = np.array(output).reshape(384, 384)
        prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
       
        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        utils.write_depth(filename, prediction, bits=2)

    print("finished")
Example #16
def run(input_path, output_path, model_path):
    """Run MonoDepthNN to compute depth maps.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
    print("initialize")

    # limit memory allocation at runtime initialization to avoid running out of GPU memory
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                #tf.config.experimental.set_memory_growth(gpu, True)
                tf.config.experimental.set_virtual_device_configuration(
                    gpu, [
                        tf.config.experimental.VirtualDeviceConfiguration(
                            memory_limit=4000)
                    ])
        except RuntimeError as e:
            print(e)

    # load network
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    model_operations = tf.compat.v1.get_default_graph().get_operations()
    input_node = '0:0'
    output_layer = model_operations[-1].name + ':0'
    print("Last layer name: ", output_layer)

    resize_image = Resize(
        384,
        384,
        resize_target=None,
        keep_aspect_ratio=False,
        ensure_multiple_of=32,
        resize_method="upper_bound",
        image_interpolation_method=cv2.INTER_CUBIC,
    )

    def compose2(f1, f2):
        return lambda x: f2(f1(x))

    transform = compose2(resize_image, PrepareForNet())

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    with tf.compat.v1.Session() as sess:
        try:
            # load images
            for ind, img_name in enumerate(img_names):

                print("  processing {} ({}/{})".format(img_name, ind + 1,
                                                       num_images))

                # input
                img = utils.read_image(img_name)
                img_input = transform({"image": img})["image"]

                # compute
                prob_tensor = sess.graph.get_tensor_by_name(output_layer)
                prediction, = sess.run(prob_tensor, {input_node: [img_input]})
                prediction = prediction.reshape(384, 384)
                prediction = cv2.resize(prediction,
                                        (img.shape[1], img.shape[0]),
                                        interpolation=cv2.INTER_CUBIC)

                # output
                filename = os.path.join(
                    output_path,
                    os.path.splitext(os.path.basename(img_name))[0])
                utils.write_depth(filename, prediction, bits=2)

        except KeyError:
            print("Couldn't find input node: " + input_node +
                  " or output layer: " + output_layer + ".")
            exit(-1)

    print("finished")
Example #17
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    crop_pct = 0.875
    scale_size = int(math.floor(args.img_size / crop_pct))

    train_loader, val_loader = None, None

    if not args.test_only:
        st = time.time()
        train_set = VOCDetection(datadir,
                                 image_set='train',
                                 download=True,
                                 transforms=Compose([
                                     VOCTargetTransform(VOC_CLASSES),
                                     RandomResizedCrop(
                                         (args.img_size, args.img_size),
                                         scale=(0.3, 1.0)),
                                     RandomHorizontalFlip(),
                                     convert_to_relative,
                                     ImageTransform(
                                         T.ColorJitter(brightness=0.3,
                                                       contrast=0.3,
                                                       saturation=0.1,
                                                       hue=0.02)),
                                     ImageTransform(T.ToTensor()),
                                     ImageTransform(normalize)
                                 ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            collate_fn=collate_fn,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCDetection(datadir,
                               image_set='val',
                               download=True,
                               transforms=Compose([
                                   VOCTargetTransform(VOC_CLASSES),
                                   Resize(scale_size),
                                   CenterCrop(args.img_size),
                                   convert_to_relative,
                                   ImageTransform(T.ToTensor()),
                                   ImageTransform(normalize)
                               ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            collate_fn=collate_fn,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = detection.__dict__[args.model](args.pretrained,
                                           num_classes=len(VOC_CLASSES),
                                           pretrained_backbone=True)

    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params,
                                     args.lr,
                                     betas=(0.95, 0.99),
                                     eps=1e-6,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'ranger':
        optimizer = Lookahead(
            holocron.optim.RAdam(model_params,
                                 args.lr,
                                 betas=(0.95, 0.99),
                                 eps=1e-6,
                                 weight_decay=args.weight_decay))
    elif args.opt == 'tadam':
        optimizer = holocron.optim.TAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)

    trainer = DetectionTrainer(model, train_loader, val_loader, None,
                               optimizer, args.device, args.output_file)

    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Loc error: {eval_metrics['loc_err']:.2%} | Clf error: {eval_metrics['clf_err']:.2%} | "
            f"Det error: {eval_metrics['det_err']:.2%}")
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Example #18
D_optimizer = Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))

root = Path('/home/xingtong/ToolTrackingData/runs/colorization_ngf_32')
# root = Path('/home/xingtong/ToolTrackingData/runs/colorization')
try:
    root.mkdir()
except OSError:
    print("directory exists!")

adding_noise = True
gaussian_std = 0.05
n_epochs = 200
report_each = 1200

train_transform = DualCompose([
    Resize(size=img_size),
    HorizontalFlip(),
    VerticalFlip(),
    ColorizationNormalize()
])

valid_transform = DualCompose([Resize(size=img_size), ColorizationNormalize()])

fold = 0
train_file_names, val_file_names = get_split(fold=fold)

batch_size = 6
num_workers = 4
train_loader = DataLoader(dataset=ColorizationDataset(
    file_names=train_file_names, transform=train_transform, to_augment=True),
                          shuffle=True,
Example #19
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size)),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    model = segmentation.__dict__[args.model](
        args.pretrained,
        not args.pretrained,
        num_classes=len(VOC_CLASSES),
    )

    # Loss setup
    loss_weight = None
    if isinstance(args.bg_factor, float):
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'label_smoothing':
        criterion = holocron.nn.LabelSmoothingCrossEntropy(weight=loss_weight,
                                                           ignore_index=255)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight,
                                                  ignore_index=255)

    # Optimizer setup
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(model_params,
                                     args.lr,
                                     betas=(0.95, 0.99),
                                     eps=1e-6,
                                     weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)

    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES))
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until, num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs, args.lr, args.freeze_until, args.sched)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")
Example #20
def run(input_path, output_path, model_path, model_type="large"):
    """Run MonoDepthNN to compute depth maps.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
    print("initialize")

    # select device
    device = "CUDA:0"
    #device = "CPU"
    print("device: %s" % device)

    # network resolution
    if model_type == "large":
        net_w, net_h = 384, 384
    elif model_type == "small":
        net_w, net_h = 256, 256
    else:
        raise ValueError(
            f"model_type '{model_type}' not implemented, use: --model_type large")

    # load network
    print("loading model...")
    model = rt.InferenceSession(model_path)
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name

    resize_image = Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            )
    
    def compose2(f1, f2):
        return lambda x: f2(f1(x))

    transform = compose2(resize_image, PrepareForNet())

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names):

        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input
        img = utils.read_image(img_name)
        img_input = transform({"image": img})["image"]

        # compute
        output = model.run([output_name], {input_name: img_input.reshape(1, 3, net_h, net_w).astype(np.float32)})[0]
        prediction = np.array(output).reshape(net_h, net_w)
        prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC)
       
        # output
        filename = os.path.join(
            output_path, os.path.splitext(os.path.basename(img_name))[0]
        )
        utils.write_depth(filename, prediction, bits=2)

    print("finished")
Example #21
def main(args):

    print(args)

    torch.backends.cudnn.benchmark = True

    # Data loading
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
    base_size = 320
    crop_size = 256
    min_size, max_size = int(0.5 * base_size), int(2.0 * base_size)

    interpolation_mode = InterpolationMode.BILINEAR

    train_loader, val_loader = None, None
    if not args.test_only:
        st = time.time()
        train_set = VOCSegmentation(args.data_path,
                                    image_set='train',
                                    download=True,
                                    transforms=Compose([
                                        RandomResize(min_size, max_size,
                                                     interpolation_mode),
                                        RandomCrop(crop_size),
                                        RandomHorizontalFlip(0.5),
                                        ImageTransform(
                                            T.ColorJitter(brightness=0.3,
                                                          contrast=0.3,
                                                          saturation=0.1,
                                                          hue=0.02)),
                                        ToTensor(),
                                        ImageTransform(normalize)
                                    ]))

        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batch_size,
            drop_last=True,
            sampler=RandomSampler(train_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(f"Training set loaded in {time.time() - st:.2f}s "
              f"({len(train_set)} samples in {len(train_loader)} batches)")

    if args.show_samples:
        x, target = next(iter(train_loader))
        plot_samples(x, target, ignore_index=255)
        return

    if not (args.lr_finder or args.check_setup):
        st = time.time()
        val_set = VOCSegmentation(args.data_path,
                                  image_set='val',
                                  download=True,
                                  transforms=Compose([
                                      Resize((crop_size, crop_size),
                                             interpolation_mode),
                                      ToTensor(),
                                      ImageTransform(normalize)
                                  ]))

        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=args.batch_size,
            drop_last=False,
            sampler=SequentialSampler(val_set),
            num_workers=args.workers,
            pin_memory=True,
            worker_init_fn=worker_init_fn)

        print(
            f"Validation set loaded in {time.time() - st:.2f}s ({len(val_set)} samples in {len(val_loader)} batches)"
        )

    if args.source.lower() == 'holocron':
        model = segmentation.__dict__[args.arch](args.pretrained,
                                                 num_classes=len(VOC_CLASSES))
    elif args.source.lower() == 'torchvision':
        model = tv_segmentation.__dict__[args.arch](
            args.pretrained, num_classes=len(VOC_CLASSES))

    # Loss setup
    loss_weight = None
    if isinstance(args.bg_factor, float) and args.bg_factor != 1:
        loss_weight = torch.ones(len(VOC_CLASSES))
        loss_weight[0] = args.bg_factor
    if args.loss == 'crossentropy':
        criterion = nn.CrossEntropyLoss(weight=loss_weight,
                                        ignore_index=255,
                                        label_smoothing=args.label_smoothing)
    elif args.loss == 'focal':
        criterion = holocron.nn.FocalLoss(weight=loss_weight, ignore_index=255)
    elif args.loss == 'mc':
        criterion = holocron.nn.MutualChannelLoss(weight=loss_weight,
                                                  ignore_index=255,
                                                  xi=3)

    # Optimizer setup
    model_params = [p for p in model.parameters() if p.requires_grad]
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(model_params,
                                    args.lr,
                                    momentum=0.9,
                                    weight_decay=args.weight_decay)
    elif args.opt == 'radam':
        optimizer = holocron.optim.RAdam(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adamp':
        optimizer = holocron.optim.AdamP(model_params,
                                         args.lr,
                                         betas=(0.95, 0.99),
                                         eps=1e-6,
                                         weight_decay=args.weight_decay)
    elif args.opt == 'adabelief':
        optimizer = holocron.optim.AdaBelief(model_params,
                                             args.lr,
                                             betas=(0.95, 0.99),
                                             eps=1e-6,
                                             weight_decay=args.weight_decay)

    log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
    trainer = SegmentationTrainer(model,
                                  train_loader,
                                  val_loader,
                                  criterion,
                                  optimizer,
                                  args.device,
                                  args.output_file,
                                  num_classes=len(VOC_CLASSES),
                                  amp=args.amp,
                                  on_epoch_end=log_wb)
    if args.resume:
        print(f"Resuming {args.resume}")
        checkpoint = torch.load(args.resume, map_location='cpu')
        trainer.load(checkpoint)

    if args.show_preds:
        x, target = next(iter(train_loader))
        with torch.no_grad():
            if isinstance(args.device, int):
                x = x.cuda()
            trainer.model.eval()
            preds = trainer.model(x)
        plot_predictions(x.cpu(), preds.cpu(), target, ignore_index=255)
        return

    if args.test_only:
        print("Running evaluation")
        eval_metrics = trainer.evaluate()
        print(
            f"Validation loss: {eval_metrics['val_loss']:.4} (Mean IoU: {eval_metrics['mean_iou']:.2%})"
        )
        return

    if args.lr_finder:
        print("Looking for optimal LR")
        trainer.lr_find(args.freeze_until,
                        norm_weight_decay=args.norm_weight_decay,
                        num_it=min(len(train_loader), 100))
        trainer.plot_recorder()
        return

    if args.check_setup:
        print("Checking batch overfitting")
        is_ok = trainer.check_setup(args.freeze_until,
                                    args.lr,
                                    norm_weight_decay=args.norm_weight_decay,
                                    num_it=min(len(train_loader), 100))
        print(is_ok)
        return

    # Training monitoring
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    exp_name = f"{args.arch}-{current_time}" if args.name is None else args.name

    # W&B
    if args.wb:

        run = wandb.init(name=exp_name,
                         project="holocron-semantic-segmentation",
                         config={
                             "learning_rate": args.lr,
                             "scheduler": args.sched,
                             "weight_decay": args.weight_decay,
                             "epochs": args.epochs,
                             "batch_size": args.batch_size,
                             "architecture": args.arch,
                             "source": args.source,
                             "input_size": 256,
                             "optimizer": args.opt,
                             "dataset": "Pascal VOC2012 Segmentation",
                             "loss": args.loss,
                         })

    print("Start training")
    start_time = time.time()
    trainer.fit_n_epochs(args.epochs,
                         args.lr,
                         args.freeze_until,
                         args.sched,
                         norm_weight_decay=args.norm_weight_decay)
    total_time_str = str(
        datetime.timedelta(seconds=int(time.time() - start_time)))
    print(f"Training time {total_time_str}")

    if args.wb:
        run.finish()
Example #22
def run(input_path, output_path, model_path):
    """Run MonoDepthNN to compute depth maps.

    Args:
        input_path (str): path to input folder
        output_path (str): path to output folder
        model_path (str): path to saved model
    """
    print("initialize")

    # select device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device: %s" % device)

    # load network
    all_props = parse_model_cfg(cfg)
    yolo_props = []
    for item in all_props:
        if item['type'] == 'yolo':
            yolo_props.append(item)
    # print(yolo_props[0])
    model = PPENet(yolo_props[0]).to(device)
    model.inference = True

    if model_path.endswith('.pt'):  # pytorch format
        # possible weights are '*.pt', 'yolov3-spp.pt', 'yolov3-tiny.pt' etc.
        midas_chkpt = torch.load(model_path, map_location=device)

        # load model
        try:
            # yolo_chkpt['model'] = {k: v for k, v in yolo_chkpt['model'].items() if model.state_dict()[k].numel() == v.numel()}
            # model.load_state_dict(yolo_chkpt['model'], strict=False)
            # own_state = load_my_state_dict(chkpt, model)
            # model.load_state_dict(own_state, strict=False)

            model_dict = model.state_dict()
            # 1. filter out unnecessary keys
            # midas_weighs_dict = {k: v for k, v in midas_chkpt.items() if k in model_dict}
            for k, v in midas_chkpt['model'].items():
                if k in model_dict:
                    model_dict[k] = v
            # 2. overwrite entries in the existing state dict
            # model_dict.update(midas_weighs_dict)
            # 3. load the new state dict
            model.load_state_dict(model_dict)
            print("Loaded Model weights successfully")

        except KeyError as e:
            s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s. " \
                "See https://github.com/ultralytics/yolov3/issues/657" % (opt.midas_weights, opt.cfg, opt.midas_weights)
            raise KeyError(s) from e

    transform = Compose([
        Resize(
            384,
            384,
            resize_target=None,
            keep_aspect_ratio=True,
            ensure_multiple_of=32,
            resize_method="upper_bound",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])

    model.to(device)
    model.eval()

    # get input
    img_names = glob.glob(os.path.join(input_path, "*"))
    num_images = len(img_names)

    # create output folder
    os.makedirs(output_path, exist_ok=True)

    print("start processing")

    for ind, img_name in enumerate(img_names):

        print("  processing {} ({}/{})".format(img_name, ind + 1, num_images))

        # input

        img = midas_utils.read_image(img_name)
        img_input = transform({"image": img})["image"]
        # print(f"DBG run img.shape - {img.shape}")
        # compute
        with torch.no_grad():
            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)
            prediction, _ = model(sample)
            # print(f"DBG prediction.shape - {prediction.shape}")
            # print(f"prediction.min() - {prediction.min()}")
            # print(f"prediction.max() - {prediction.max()}")
            prediction = (torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze().cpu().numpy())

        # output
        filename = os.path.join(
            output_path,
            os.path.splitext(os.path.basename(img_name))[0])
        midas_utils.write_depth(filename, prediction, bits=2)

    print("finished")
Example #23
def main(argv):
    params = args_parsing(cmd_args_parsing(argv))
    root, experiment_name, image_size, batch_size, lr, n_epochs, log_dir, checkpoint_path = (
        params['root'], params['experiment_name'], params['image_size'],
        params['batch_size'], params['lr'], params['n_epochs'],
        params['log_dir'], params['checkpoint_path'])

    train_val_split(os.path.join(root, DATASET_TABLE_PATH))
    dataset = pd.read_csv(os.path.join(root, DATASET_TABLE_PATH))

    pre_transforms = torchvision.transforms.Compose(
        [Resize(size=image_size), ToTensor()])
    batch_transforms = torchvision.transforms.Compose(
        [BatchEncodeSegmentaionMap()])
    augmentation_batch_transforms = torchvision.transforms.Compose([
        BatchToPILImage(),
        BatchHorizontalFlip(p=0.5),
        BatchRandomRotation(degrees=10),
        BatchRandomScale(scale=(1.0, 2.0)),
        BatchBrightContrastJitter(brightness=(0.5, 2.0), contrast=(0.5, 2.0)),
        BatchToTensor(),
        BatchEncodeSegmentaionMap()
    ])

    train_dataset = SegmentationDataset(
        dataset=dataset[dataset['phase'] == 'train'], transform=pre_transforms)

    train_sampler = SequentialSampler(train_dataset)
    train_batch_sampler = BatchSampler(train_sampler, batch_size)
    train_collate = collate_transform(augmentation_batch_transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=train_collate)

    val_dataset = SegmentationDataset(
        dataset=dataset[dataset['phase'] == 'val'], transform=pre_transforms)

    val_sampler = SequentialSampler(val_dataset)
    val_batch_sampler = BatchSampler(val_sampler, batch_size)
    val_collate = collate_transform(batch_transforms)
    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_sampler=val_batch_sampler,
        collate_fn=val_collate)

    # model = Unet_with_attention(1, 2, image_size[0], image_size[1]).to(device)
    # model = UNet(1, 2).to(device)
    # model = UNetTC(1, 2).to(device)

    model = UNetFourier(1, 2, image_size, fourier_layer='linear').to(device)

    writer, experiment_name, best_model_path = setup_experiment(
        model.__class__.__name__, log_dir, experiment_name)

    new_checkpoint_path = os.path.join(root, 'checkpoints',
                                       experiment_name + '_latest.pth')
    best_checkpoint_path = os.path.join(root, 'checkpoints',
                                        experiment_name + '_best.pth')
    os.makedirs(os.path.dirname(new_checkpoint_path), exist_ok=True)

    if checkpoint_path is not None:
        checkpoint_path = os.path.join(root, 'checkpoints', checkpoint_path)
        print(f"\nLoading checkpoint from {checkpoint_path}.\n")
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = None
    best_model_path = os.path.join(root, best_model_path)
    print(f"Experiment name: {experiment_name}")
    print(f"Model has {count_parameters(model):,} trainable parameters")
    print()

    criterion = CombinedLoss(
        [CrossEntropyLoss(),
         GeneralizedDiceLoss(weighted=True)], [0.4, 0.6])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           factor=0.5,
                                                           patience=5)
    metric = DiceMetric()
    weighted_metric = DiceMetric(weighted=True)

    print(
        "To monitor training, run this command in a new terminal:\ntensorboard --logdir <path to log directory>"
    )
    print()
    train(model, train_dataloader, val_dataloader, criterion, optimizer,
          scheduler, metric, weighted_metric, n_epochs, device, writer,
          best_model_path, best_checkpoint_path, checkpoint,
          new_checkpoint_path)
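
# `collate_transform` is not defined in this example; below is a minimal
# sketch of a compatible helper, under the assumption that each sample is an
# (image, mask) tensor pair and that the batch transform pipeline accepts a
# stacked (images, masks) tuple. An illustration, not the repo's actual code.
def collate_transform(batch_transforms):
    def collate(samples):
        # stack per-sample tensors into batch tensors ...
        images = torch.stack([image for image, _ in samples])
        masks = torch.stack([mask for _, mask in samples])
        # ... then run the batch-level transforms on the whole batch
        return batch_transforms((images, masks))
    return collate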
Beispiel #24
0
    def train(self, folder, n_epochs, tsize=512):
        """
        train the model

        Parameters
        ----------
        folder: string
            path to the training set folder

        n_epochs: int
            number of training epochs

        tsize: int
            segmentation tile size
        """

        if n_epochs < 1:
            raise ValueError("'n_epochs' must be greater than or equal to 1")

        # training parameters
        learning_rate = 0.0001
        criterion = SegLoss(self._c_weights)  # custom loss
        optimizer = torch.optim.Adam(self._model.parameters(),
                                     lr=learning_rate)

        # inits
        self._model.train()
        tf_resize = Resize()
        dataset = ImgDataset(folder)

        # training loop
        for epoch in range(n_epochs):
            img_count = 0
            sum_loss = 0
            tile_count = 0

            for image, mask in dataset:
                tile_dataset = TileDataset(image,
                                           mask,
                                           tsize=tsize,
                                           mask_merge=(self._n_classes <= 2))
                tile_loader = DataLoader(tile_dataset, batch_size=1)

                for i, (x, y, _) in enumerate(tile_loader):
                    # batch
                    x = x.to(self._device)
                    y = y.to(self._device)

                    # forward pass
                    y_pred = self._model(x)
                    if y_pred.shape != y.shape:
                        y = tf_resize(y, (y_pred.shape[2], y_pred.shape[3]))
                    loss = criterion(y_pred, y)
                    sum_loss += loss.item()

                    # backward pass
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # verbose
                    tile_count += 1
                    print("epoch: " + str(epoch + 1) + "/" + str(n_epochs) +
                          ", image: " + str(img_count + 1) + "/" +
                          str(len(dataset)) + ", iteration: " + str(i + 1) +
                          "/" + str(len(tile_dataset)) + ", avg_loss: " +
                          str(round(sum_loss / tile_count, 4)))
                img_count += 1
        print("training done")
Beispiel #25
0
    def segment(self, folder, dest="", tsize=512, transform=None):
        """
        segment a folder of images

        Parameters
        ----------
        folder: string
            folder containing the images to segment

        dest: string
            folder where the predicted masks are written

        tsize: int
            segmentation tile size

        transform: Transform
            transform applied to the predicted masks
        """
        
        # inits
        self.set_eval()
        tf_resize = Resize()
        dataset = ImgDataset(folder)
        img_count = 0
        mask_p = None
        sum_jaccard = 0

        for image, mask in dataset:
            img_count += 1
            mask_p = None  # reset the predicted mask for each new image
            sum_intersection, sum_union = 0, 0
            tile_dataset = TileDataset(image, mask, tsize=tsize,
                                       mask_merge=(self._n_classes <= 2))
            dl = DataLoader(dataset=tile_dataset, batch_size=1)

            for tile, tile_mask, tile_id in dl:
                # compute tile position and size without padding
                offset = tile_dataset.topology.tile_offset(tile_id)
                off_x, off_y = offset[1].item(), offset[0].item()
                t_h, t_w = tsize, tsize
                if off_x + tsize > image.height:
                    t_h = image.height - off_x
                if off_y + tsize > image.width:
                    t_w = image.width - off_y

                # compute predicted tile mask
                tile_mask_p = self.predict(tile, transform).cpu()
                # resize if necessary
                if tile_mask_p.shape != tile_mask.shape:
                    tile_mask_p = tf_resize(tile_mask_p, (tsize, tsize))
                # select the area without padding (the first mask channel is skipped)
                tile_mask_p = tile_mask_p[:, 1:, :t_h, :t_w].int()
                tile_mask = tile_mask[:, 1:, :t_h, :t_w].int()

                # compute intersection and union
                sum_intersection += torch.sum(torch.bitwise_and(tile_mask_p, tile_mask))
                sum_union += torch.sum(torch.bitwise_or(tile_mask_p, tile_mask))
                # TODO: tile overlap and merging should be taken into account when computing IoU
                # TODO: this is computed on the whole tensor; channel-wise might be better (see the sketch after this example)
                
                # write the predicted tile mask to the predicted mask
                # bitwise_or is used for tiles merging
                if mask_p is None:
                    mask_p = torch.zeros(tile_mask_p.shape[1], image.height, 
                                         image.width, dtype=torch.int)
                mask_p[:, off_x:(off_x+t_h), off_y:(off_y+t_w)] = torch.bitwise_or(
                    mask_p[:, off_x:(off_x+t_h), off_y:(off_y+t_w)], tile_mask_p.squeeze(0))

            # write the predicted mask channels to files
            if dest == "":
                dest = folder
            if not os.path.exists(dest):
                os.makedirs(dest)
            for i in range(mask_p.shape[0]):
                filename = os.path.join(dest, f'{img_count}_yp_{i+1}.png')
                cv2.imwrite(filename, mask_p[i].numpy() * 255)

            # compute jaccard
            jaccard = 1
            if sum_union != 0:
                jaccard = sum_intersection / sum_union
            sum_jaccard += jaccard
            print(f'image: {img_count}/{len(dataset)}, jaccard: {jaccard:.4f}')
        print(f'average jaccard: {(sum_jaccard/len(dataset)):.4f}')
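
# Channel-wise variant suggested by the TODO above; a sketch under the same
# assumptions (int masks of shape (C, H, W)). With eps, an empty union gives
# 1.0, matching the "jaccard = 1 when union == 0" convention used above.
def channelwise_jaccard(mask_pred, mask_true, eps=1e-7):
    inter = torch.sum(torch.bitwise_and(mask_pred, mask_true), dim=(1, 2))
    union = torch.sum(torch.bitwise_or(mask_pred, mask_true), dim=(1, 2))
    return (inter + eps) / (union + eps)  # one IoU value per channel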
Beispiel #26
0
def init(batch_size, state, input_sizes, std, mean, dataset, city_aug=0):
    # Return data loaders depending on the state:
    # 1: just testing (val_loader only)
    # otherwise: training (train_loader, val_loader)

    # Transformations
    # ! Can't use torchvision.transforms.Compose (these transforms act on image and label jointly)
    if dataset == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose([
            ToTensor(),
            RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
            RandomCrop(size=input_sizes[0]),
            RandomHorizontalFlip(flip_prob=0.5),
            Normalize(mean=mean, std=std)
        ])
        transform_test = Compose([
            ToTensor(),
            ZeroPad(size=input_sizes[2]),
            Normalize(mean=mean, std=std)
        ])
    elif dataset in ('city', 'gtav', 'synthia'):  # All the same size
        if dataset == 'city':
            base = base_city
        elif dataset == 'gtav':
            base = base_gtav
        else:
            base = base_synthia
        outlier = dataset != 'city'  # GTAV labels contain broken/outlier IDs
        workers = 8

        if city_aug == 3:  # SYNTHIA & GTAV
            if dataset == 'gtav':
                transform_train = Compose([
                    ToTensor(),
                    Resize(size_label=input_sizes[1],
                           size_image=input_sizes[1]),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_city, outlier=outlier)
                ])
            else:
                transform_train = Compose([
                    ToTensor(),
                    RandomCrop(size=input_sizes[0]),
                    RandomHorizontalFlip(flip_prob=0.5),
                    Normalize(mean=mean, std=std),
                    LabelMap(label_id_map_synthia, outlier=outlier)
                ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 2:  # ERFNet
            transform_train = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[0], size_label=input_sizes[2]),
                LabelMap(label_id_map_city)
            ])
        elif city_aug == 1:  # City big
            transform_train = Compose([
                ToTensor(),
                RandomCrop(size=input_sizes[0]),
                LabelMap(label_id_map_city, outlier=outlier),
                RandomTranslation(trans_h=2, trans_w=2),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
        else:  # Standard city
            transform_train = Compose([
                ToTensor(),
                RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                RandomCrop(size=input_sizes[0]),
                RandomHorizontalFlip(flip_prob=0.5),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city, outlier=outlier)
            ])
            transform_test = Compose([
                ToTensor(),
                Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
                Normalize(mean=mean, std=std),
                LabelMap(label_id_map_city)
            ])
    else:
        raise ValueError(f"unsupported dataset: {dataset}")

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(
        root=base_city if dataset == 'gtav' or dataset == 'synthia' else base,
        image_set='val',
        transforms=transform_test,
        data_set='city'
        if dataset == 'gtav' or dataset == 'synthia' else dataset)
    if city_aug == 1:
        val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                                 batch_size=1,
                                                 num_workers=workers,
                                                 shuffle=False)
    else:
        val_loader = torch.utils.data.DataLoader(dataset=test_set,
                                                 batch_size=batch_size,
                                                 num_workers=workers,
                                                 shuffle=False)

    # Testing
    if state == 1:
        return val_loader
    else:
        # Training
        train_set = StandardSegmentationDataset(
            root=base,
            image_set='trainaug' if dataset == 'voc' else 'train',
            transforms=transform_train,
            data_set=dataset)
        train_loader = torch.utils.data.DataLoader(dataset=train_set,
                                                   batch_size=batch_size,
                                                   num_workers=workers,
                                                   shuffle=True)
        return train_loader, val_loader
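
# The Compose used here is not torchvision's (see the note at the top of
# init): each transform must receive image and label together so geometric
# ops stay in sync. A minimal sketch of such a joint compose, assuming an
# (image, target) interface; not the repo's actual implementation.
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target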
Beispiel #27
0
def init(batch_size, state, split, input_sizes, sets_id, std, mean, keep_scale, reverse_channels, data_set,
         valtiny, no_aug):
    # Return data_loaders/data_loader
    # depending on the state:
    # 1: semi-supervised training
    # 2: fully-supervised training
    # 3: just testing

    # Transformations (compatible with unlabeled data/pseudo labeled data)
    # ! Can't use torchvision.transforms.Compose (these transforms act on image and label jointly)
    if data_set == 'voc':
        base = base_voc
        workers = 4
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             ZeroPad(size=input_sizes[2]),
             Normalize(mean=mean, std=std)])
    elif data_set == 'city':  # All the same size (whole set is down-sampled by 2)
        base = base_city
        workers = 8
        transform_train = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
             RandomCrop(size=input_sizes[0]),
             RandomHorizontalFlip(flip_prob=0.5),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        if no_aug:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
                 Normalize(mean=mean, std=std)])
        else:
            transform_train_pseudo = Compose(
                [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
                 RandomResize(min_size=input_sizes[0], max_size=input_sizes[1]),
                 RandomCrop(size=input_sizes[0]),
                 RandomHorizontalFlip(flip_prob=0.5),
                 Normalize(mean=mean, std=std)])
        transform_pseudo = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[0], size_label=input_sizes[0]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
        transform_test = Compose(
            [ToTensor(keep_scale=keep_scale, reverse_channels=reverse_channels),
             Resize(size_image=input_sizes[2], size_label=input_sizes[2]),
             Normalize(mean=mean, std=std),
             LabelMap(label_id_map_city)])
    else:
        # fail fast: the transforms below would be undefined otherwise
        raise ValueError(f"unsupported data_set: {data_set}")

    # Not the actual test set (i.e. validation set)
    test_set = StandardSegmentationDataset(root=base, image_set='valtiny' if valtiny else 'val',
                                           transforms=transform_test, label_state=0, data_set=data_set)
    val_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, num_workers=workers, shuffle=False)

    # Testing
    if state == 3:
        return val_loader
    else:
        # Fully-supervised training
        if state == 2:
            labeled_set = StandardSegmentationDataset(root=base, image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0, data_set=data_set)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set, batch_size=batch_size,
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, val_loader

        # Semi-supervised training
        elif state == 1:
            pseudo_labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                             image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                             transforms=transform_train_pseudo, label_state=1)
            reference_set = SegmentationLabelsDataset(root=base, image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                      data_set=data_set)
            reference_loader = torch.utils.data.DataLoader(dataset=reference_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)
            unlabeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                        image_set=(str(split) + '_unlabeled_' + str(sets_id)),
                                                        transforms=transform_pseudo, label_state=2)
            labeled_set = StandardSegmentationDataset(root=base, data_set=data_set,
                                                      image_set=(str(split) + '_labeled_' + str(sets_id)),
                                                      transforms=transform_train, label_state=0)

            unlabeled_loader = torch.utils.data.DataLoader(dataset=unlabeled_set, batch_size=batch_size,
                                                           num_workers=workers, shuffle=False)

            pseudo_labeled_loader = torch.utils.data.DataLoader(dataset=pseudo_labeled_set,
                                                                batch_size=int(batch_size / 2),
                                                                num_workers=workers, shuffle=True)
            labeled_loader = torch.utils.data.DataLoader(dataset=labeled_set,
                                                         batch_size=int(batch_size / 2),
                                                         num_workers=workers, shuffle=True)
            return labeled_loader, pseudo_labeled_loader, unlabeled_loader, val_loader, reference_loader

        else:
            # Extend here to support unsupervised learning if that is needed
            raise ValueError(f"unsupported state: {state}")
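
# Return shapes for the three states above (taken from the code; arguments
# elided on purpose):
#   state == 3 -> val_loader
#   state == 2 -> labeled_loader, val_loader
#   state == 1 -> labeled_loader, pseudo_labeled_loader, unlabeled_loader,
#                 val_loader, reference_loader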
Beispiel #28
0
import dataset
from dataset import get_split
from torch.utils.data import DataLoader
from transforms import (DualCompose, Resize, MaskOnly, MaskShiftScaleRotate,
                        MaskShift, Normalize, HorizontalFlip, VerticalFlip)
import utils

mask_file_names = dataset.get_file_names('/home/xingtong/CAD_models/knife1',
                                         'knife_mask')
print(mask_file_names)

color_file_names = dataset.get_file_names('/home/xingtong/CAD_models/knife1',
                                          'color_')
print(color_file_names)

img_size = 128
train_transform = DualCompose([
    Resize(size=img_size),
    HorizontalFlip(),
    VerticalFlip(),
    Normalize(normalize_mask=True),
    MaskOnly([MaskShiftScaleRotate(scale_upper=4.0),
              MaskShift(limit=50)])
])

num_workers = 4
batch_size = 6
train_loader = DataLoader(dataset=dataset.CADDataset(
    color_file_names=color_file_names,
    mask_file_names=mask_file_names,
    transform=train_transform),
                          shuffle=True,
                          num_workers=num_workers,
                          batch_size=batch_size)
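
# MaskOnly above presumably applies its wrapped transforms to the mask while
# leaving the color image untouched (e.g. to simulate mask misalignment during
# training). A hypothetical sketch of that interface; not this repo's code.
class MaskOnly:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        for t in self.transforms:
            mask = t(mask)  # wrapped transforms act on the mask only
        return img, mask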
Beispiel #29
0
    epoch_to_use = 44
    use_previous_model = True
    batch_size = 8
    num_workers = 8
    n_epochs = 1500
    gamma = 0.99

    img_width = 768
    img_height = 768

    display_img_height = 300
    display_img_width = 300

    test_transform = DualCompose(
        [MaskLabel(),
         Resize(w=img_width, h=img_height),
         Normalize()])

    input_path = "../datasets/lumi/A/test"
    label_path = "../datasets/lumi/B/test"

    input_file_names = utils.read_lumi_filenames(input_path)
    label_file_names = utils.read_lumi_filenames(label_path)

    dataset = dataset.LumiDataset(input_filenames=input_file_names,
                                  label_filenames=label_file_names,
                                  transform=test_transform)
    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=num_workers)
Beispiel #30
0
def visualizeImage3D(vol_path: str, proc_vol_path: str):
    img_proc = nib.load(proc_vol_path)
    img_unproc = nib.load(vol_path)

    data_proc = img_proc.get_fdata()
    data_unproc = img_unproc.get_fdata()

    resize_proc = Resize((data_proc, img_proc.affine, img_proc.header),
                         (50, 50, 26))
    resize_unproc = Resize((data_unproc, img_unproc.affine, img_unproc.header),
                           (50, 50, 26))

    data_proc = resize_proc()
    data_unproc = resize_unproc()

    x_p = []
    y_p = []
    z_p = []
    c_p = []
    x = []
    y = []
    z = []
    c = []

    for sur in range(data_proc.shape[0]):
        for row in range(data_proc.shape[1]):
            for col in range(data_proc.shape[2]):
                x_p.append(sur)
                y_p.append(row)
                z_p.append(col)
                c_p.append(data_proc[sur][row][col])

    for sur in range(data_unproc.shape[0]):
        for row in range(data_unproc.shape[1]):
            for col in range(data_unproc.shape[2]):
                x.append(sur)
                y.append(row)
                z.append(col)
                c.append(data_unproc[sur][row][col])

    fig_proc = plt.figure()
    plt.suptitle("3D view of the processed MRI volume")
    fig_unproc = plt.figure()
    plt.suptitle("3D view of the unprocessed MRI volume")
    ax_proc = fig_proc.add_subplot(111, projection='3d')
    ax_unproc = fig_unproc.add_subplot(111, projection='3d')

    # Choose colormap
    cmap = pl.cm.RdBu

    # Get the colormap colors
    my_cmap = cmap(np.arange(cmap.N))

    # Set alpha
    my_cmap[:, -1] = np.linspace(0, 1, cmap.N)

    # Create new colormap
    my_cmap = ListedColormap(my_cmap)

    scat_proc = ax_proc.scatter(x_p, y_p, z_p, c=c_p, cmap=my_cmap)
    scat_unproc = ax_unproc.scatter(x, y, z, c=c, cmap=my_cmap)

    fig_proc.colorbar(scat_proc)
    fig_unproc.colorbar(scat_unproc)

    move_figure(fig_unproc, 830, 83)
    move_figure(fig_proc, 0, 83)

    plt.show()
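
# The triple loops above flatten each volume into coordinate/value lists; a
# vectorized NumPy equivalent (np is assumed imported, as elsewhere in the
# snippet). Both ravel in C order, so the ordering matches the loops.
def volume_to_points(data):
    x, y, z = np.indices(data.shape)
    return x.ravel(), y.ravel(), z.ravel(), data.ravel()

# e.g.: x_p, y_p, z_p, c_p = volume_to_points(data_proc)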