# Example No. 1
# 0
def augment_data(save_dir):
    """
    Apply the data augmentation pipeline to every sample in ``data_file.h5``.

    Reads the ``data/data`` dataset from ``<save_dir>/data_file.h5``, passes
    each sample through an albumentations pipeline and writes the results to
    ``<save_dir>/data_aug.h5`` under the same ``data/data`` key as float32.

    The pipeline is loaded from ``<save_dir>/aug_pipeline_1.json``. If that
    file does not exist, a default pipeline (one flip, then with p=0.8 one
    distortion) is built and saved there so subsequent runs reuse it.

    Python's ``random`` module is seeded with 1337 before augmenting.

    :param save_dir: Directory holding ``data_file.h5``; the pipeline JSON and
        ``data_aug.h5`` are also written here.
    :return: None; results are written to ``data_aug.h5``.
    """
    seed = 1337
    random.seed(seed)
    start_time = time.time()
    print(f"====== Augmenting data. Seed set at {seed} ======")

    # Use a context manager so the input file is closed even if loading the
    # pipeline or augmenting a sample raises.
    with h5py.File(os.path.join(save_dir, 'data_file.h5'), 'r') as data_file:
        dataset = data_file['data/data']
        data_shape = dataset.shape

        data_aug = np.zeros(shape=data_shape, dtype=np.float32)

        n_samples = data_shape[0]
        # The unpacking requires at least 5 dims, presumably
        # (N, C, H, W, D, ...) — TODO confirm against the data producer.
        img_channels, img_height, img_width, img_depth = data_shape[1:5]

        pipeline_path = os.path.join(save_dir, 'aug_pipeline_1.json')
        try:
            aug = alb.load(pipeline_path)
        except FileNotFoundError:
            print("Pipeline not found. Generating One ...")
            aug = Compose([
                OneOf([VerticalFlip(p=1), HorizontalFlip(p=1)], p=1),
                OneOf([
                    ElasticTransform(p=1, sigma=6, alpha_affine=4, alpha=75),
                    GridDistortion(p=1),
                    OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
                ],
                      p=0.8)
            ])
            # Persist the generated pipeline so later runs load the same one.
            alb.save(aug, pipeline_path)

        for data_idx in range(n_samples):
            img = dataset[data_idx, ...]
            img = img.reshape(img_channels, img_height, img_width, -1)
            # NOTE(review): only channel 0 is augmented, yet the result is
            # reshaped back to img_channels channels; this looks like it can
            # only succeed when img_channels == 1 — confirm.
            img_aug = aug(image=img[0, ...])['image']
            data_aug[data_idx, ...] = img_aug.reshape(
                img_channels, img_height, img_width, img_depth, -1
            )

    with h5py.File(os.path.join(save_dir, 'data_aug.h5'), 'w') as out_file:
        out_file.create_dataset('data/data', data=data_aug, dtype=np.float32)

    print(
        f"====== Finished augmentation. Time taken: {time.time() - start_time}s ======"
    )
# Example No. 2
# 0
    def save_policy(self):
        """Serialize the current policy as a transform file for this epoch.

        The policy model builds an albumentations transform (with this
        object's preprocessing transforms). It is written to
        ``<policy_dir>/epoch_<epoch>.json`` with ``A.save``.
        ``self.paths["latest_policy_path"]`` is then pointed at that file via
        ``symlink``, and the action is logged.
        """
        policy_model = self.models["policy"]
        policy_transform = policy_model.create_transform(
            input_dtype=self.cfg.data.input_dtype,
            preprocessing_transforms=self.get_preprocessing_transforms(),
        )
        epoch_file = self.paths["policy_dir"] / f"epoch_{self.epoch}.json"
        A.save(policy_transform, str(epoch_file))

        latest_path = self.paths["latest_policy_path"]
        symlink(epoch_file, latest_path)
        log.info(
            f"Policy is saved to {epoch_file}. "
            f"{latest_path} now also points to this policy file."
        )
def test_augmentations_serialization_to_file_with_custom_parameters(
    augmentation_cls, params, p, seed, image, mask, always_apply, data_format
):
    """Saving and reloading a parametrized transform must not change its output.

    The transform is saved with ``A.save`` and reloaded with ``A.load`` in the
    given ``data_format``. Both the original and the reloaded transform are
    then run on the same image/mask after resetting the seed. Their ``image``
    and ``mask`` results must be equal.
    """
    # Presumably OpenMock is an in-memory stand-in for open() so no real file
    # is touched — confirm.
    with patch("builtins.open", OpenMock()):
        original = augmentation_cls(p=p, always_apply=always_apply, **params)
        path = f"serialized.{data_format}"
        A.save(original, path, data_format=data_format)
        restored = A.load(path, data_format=data_format)

        results = []
        for transform in (original, restored):
            # Reset the seed before each call so both draw the same randomness.
            set_seed(seed)
            results.append(transform(image=image, mask=mask))

        expected, actual = results
        for key in ("image", "mask"):
            assert np.array_equal(expected[key], actual[key])
# Example No. 4
# 0
 def on_epoch_end(self, trainer, pl_module):
     """Write the policy transform for the epoch that just finished.

     The policy model's transform is saved to ``<dirpath>/epoch_<epoch>.json``
     and then copied over ``self.latest_policy_filepath``.

     :param trainer: Supplies ``current_epoch`` and ``datamodule``.
     :param pl_module: Supplies ``cfg`` and ``policy_model``.
     """
     epoch_num = trainer.current_epoch
     dm = trainer.datamodule
     input_dtype = pl_module.cfg.data.input_dtype
     policy_transform = pl_module.policy_model.create_transform(
         input_dtype=input_dtype,
         preprocessing_transforms=dm.get_preprocessing_transforms(),
     )

     target = os.path.join(self.dirpath, f"epoch_{epoch_num}.json")
     A.save(policy_transform, target)
     # Refresh the "latest" file with a full copy (content and metadata).
     shutil.copy2(target, self.latest_policy_filepath)

     log.info(
         f"Policy is saved to {target}. "
         f"{self.latest_policy_filepath} now also contains this policy.")
# Example No. 5
# 0
def get_augmentation(save_path=None, load_path=None):
    """Return an albumentations pipeline, loaded from disk or built fresh.

    The fresh pipeline applies, in order, exactly one transform from each of
    three groups: a colour adjustment, a noise model, and a photometric /
    blur / fog effect. No geometric, resize or normalization stage is
    included.

    :param save_path: If truthy and a new pipeline is built, it is serialized
        to this path with ``A.save``. Ignored when ``load_path`` is truthy.
    :param load_path: If truthy, the pipeline is deserialized from this path
        with ``A.load`` and returned without building or saving anything.
    :return: The loaded transform, or the newly built ``A.Compose``.
    """
    if load_path:
        return A.load(load_path)

    # Exactly one colour adjustment.
    color_aug = A.OneOf([
        A.RGBShift(r_shift_limit=15, g_shift_limit=15,
                   b_shift_limit=15, p=1.0),
        A.RandomBrightnessContrast(always_apply=False, p=1.0,
                                   brightness_limit=(-0.2, 0.2),
                                   contrast_limit=(-0.2, 0.2),
                                   brightness_by_max=True),
    ], p=1.0)

    # Exactly one noise model.
    noise_aug = A.OneOf([
        A.GaussNoise(always_apply=False, p=1.0, var_limit=(10, 50)),
        A.ISONoise(always_apply=False, p=1.0, intensity=(0.1, 1.0),
                   color_shift=(0.01, 0.3)),
        A.MultiplicativeNoise(always_apply=False, p=1.0,
                              multiplier=(0.8, 1.6), per_channel=True,
                              elementwise=True),
    ], p=1.0)

    # Exactly one photometric / blur / weather effect.
    effect_aug = A.OneOf([
        A.Equalize(always_apply=False, p=1.0, mode='pil', by_channels=True),
        A.InvertImg(always_apply=False, p=1.0),
        A.MotionBlur(always_apply=False, p=1.0, blur_limit=(3, 7)),
        A.RandomFog(always_apply=False, p=1.0, fog_coef_lower=0.01,
                    fog_coef_upper=0.2, alpha_coef=0.2),
    ], p=1.0)

    aug_seq = A.Compose([color_aug, noise_aug, effect_aug])
    if save_path:
        A.save(aug_seq, save_path)
    return aug_seq
# Example No. 6
# 0
def main():
    """Demo: run a crop + colour-jitter pipeline on ``cuiyan.png`` and serialize it.

    Steps:
    - Read ``cuiyan.png`` from the current directory and convert it from
      OpenCV's BGR order to RGB.
    - Apply the pipeline once after ``random.seed(42)``.
    - Draw the result with matplotlib (``plt.show()`` is not called).
    - Save the pipeline as ``./transform.json`` and ``./transform.yml``.
    - Pretty-print the pipeline's dictionary form.

    :raises FileNotFoundError: if ``cuiyan.png`` cannot be read.
    """
    image_path = "cuiyan.png"
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail clearly here rather than with an opaque cvtColor error.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def visualize(image):
        """Draw *image* on a new 6x6-inch figure with the axes hidden."""
        plt.figure(figsize=(6, 6))
        plt.axis("off")
        plt.imshow(image)

    transform = A.Compose([
        A.RandomCrop(111, 222),
        A.OneOf([A.RGBShift(), A.HueSaturationValue()]),
    ])

    # Seed Python's RNG, presumably so the sampled crop/jitter is repeatable.
    random.seed(42)
    transformed = transform(image=image)
    visualize(transformed["image"])

    A.save(transform, "./transform.json")
    A.save(transform, "./transform.yml", data_format="yaml")
    pprint.pprint(A.to_dict(transform))
# Example No. 7
# 0
    dataset = NYUDataset(depths, is_train=False, transforms=augm)

    print(dataset.RGB_frames)
    img = read_image(dataset.RGB_frames[-1])
    plt.imshow(img)
    plt.figure()
    depth = dataset.read_depth(dataset.depth_frames[-1])
    print(depth.dtype)
    print(np.max(depth))
    print(np.min(depth))

    plt.imshow(depth, cmap='Greys_r')

    data = {"image": np.array(img), "mask": depth}
    augm = strong_aug(0.9)
    A.save(augm, 'transform_prova.json')

    fig = plt.figure()
    fig = plt.figure()
    labels = []
    columns = 2
    rows = 2
    for i in range(1, columns * rows + 1):
        augmented = augm(**data)
        fig.add_subplot(rows, columns, i)
        plt.imshow(augmented["mask"], cmap='Greys_r')
        labels.append(augmented["image"])

    fig = plt.figure()

    for _i, i in enumerate(range(1, columns * rows + 1)):
# Example No. 8
# 0
    A.Equalize(always_apply=False, p=1.0, mode='pil', by_channels=True),
    A.InvertImg(always_apply=False, p=1.0),
    A.MotionBlur(always_apply=False, p=1.0, blur_limit=(3, 7)),
    A.OpticalDistortion(always_apply=False, p=1.0, distort_limit=(-0.3, 0.3), shift_limit=(-0.05, 0.05), interpolation=0, border_mode=0, value=(0, 0, 0), mask_value=None),
    A.RandomFog(always_apply=False, p=1.0, fog_coef_lower=0.1, fog_coef_upper=0.45, alpha_coef=0.5)
    ], p=1.0)
# Training pipeline: resize to IMG_SIZE x IMG_SIZE, then the four stages
# aug_seq1..aug_seq4 defined earlier in the file (presumably A.OneOf groups —
# confirm), then normalize with mean 0.5 / std 0.5 per channel (presumably
# scaling pixels to about [-1, 1] with the default max_pixel_value — confirm).
aug_seq = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    aug_seq1,
    aug_seq2,
    aug_seq3,
    aug_seq4,
    A.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# NOTE(review): hard-coded, user-specific absolute path; consider making it
# configurable.
aug_path = '/home/jitesh/prj/classification/test/bolt/aug/aug_seq.json'
A.save(aug_seq, aug_path)
# Reload the file just written; presumably a round-trip sanity check.
# loaded_transform is not used in the visible code.
loaded_transform = A.load(aug_path)
# In[6]:

class BoltDataset(Dataset):
    def __init__(self, file_list, dir, mode='train', transform = None, test_label: int=1):
        self.file_list = file_list
        self.dir = dir
        self.mode= mode
        # self.transform = transform
        self.test_label = test_label
        if self.mode == 'train':
            # print(self.file_list)
            # if 'b00' in self.file_list[0]:
            if 'b10' in self.file_list[0]:
                self.label = 0