Пример #1
0
    def _augment(img, depth, label):
        """Augment an image, its depth map and its label with one shared draw.

        A single pipeline is sampled once (via ``to_deterministic``) and then
        applied three times. Hooks decide, per input, which augmenters run and
        how the affine warp interpolates and pads:

        * ``img``: every augmenter runs; Affine uses order=1, cval=0.
        * ``depth``: only Affine/Sequential/Sometimes run; Affine uses
          order=1, cval=0.
        * ``label``: only Affine/Sequential/Sometimes run; Affine uses
          order=0, cval=-1.

        Returns:
            The tuple ``(img, depth, label)`` after augmentation.
        """

        def _maybe(child):
            # Apply ``child`` with probability 0.5.
            return iaa.Sometimes(0.5, child)

        # Stages are built in the same order in which they are listed below.
        channel_gain = _maybe(
            iaa.WithChannels([0, 1], iaa.Multiply([1, 1.5])))
        hsv_gain = _maybe(
            iaa.InColorspace(
                'HSV',
                children=iaa.WithChannels([1, 2], iaa.Multiply([0.5, 2])),
            ))
        blur = iaa.GaussianBlur(sigma=[0, 1])
        dropout = iaa.Sometimes(0.9, iaa.Dropout(p=(0, 0.4), name='dropout'))
        warp = iaa.Sometimes(
            0.9,
            iaa.Affine(
                order=1,
                cval=0,
                scale=1,
                translate_px=(-96, 96),
                rotate=(-180, 180),
                mode='constant',
            ))
        pipeline = iaa.Sequential(
            [channel_gain, hsv_gain, blur, dropout, warp],
            random_order=True,
        )

        # Augmenter types that are still allowed to run on depth and label.
        structural_types = (iaa.Affine, iaa.Sequential, iaa.Sometimes)

        def _make_activator(order, cval, allowed=None):
            """Return a hook activator configuring Affine for one input kind.

            When ``allowed`` is given, augmenters of any other type are
            switched off; otherwise every augmenter is kept active.
            """

            def _activate(images, augmenter, parents, default):
                if allowed is not None and not isinstance(augmenter, allowed):
                    return False
                if isinstance(augmenter, iaa.Affine):
                    # Set interpolation order and fill value right before
                    # this particular input is warped.
                    augmenter.order = Deterministic(order)
                    augmenter.cval = Deterministic(cval)
                return True

            return _activate

        frozen = pipeline.to_deterministic()

        img = frozen.augment_image(
            img,
            hooks=ia.HooksImages(activator=_make_activator(1, 0)))
        depth = frozen.augment_image(
            depth,
            hooks=ia.HooksImages(
                activator=_make_activator(1, 0, structural_types)))
        label = frozen.augment_image(
            label,
            hooks=ia.HooksImages(
                activator=_make_activator(0, -1, structural_types)))

        return img, depth, label
Пример #2
0
def augment_object_data(object_data,
                        random_state=None,
                        fit_output=True,
                        aug_color=True,
                        aug_geo=True,
                        augmentations=None,
                        random_order=False,
                        scale=(0.5, 1.0)):
    """Lazily augment a sequence of object dicts, yielding each one in turn.

    For every dict a fresh deterministic copy of the pipeline is drawn and
    applied to all image-like entries of that dict. Hooks select which
    augmenters run and how ``iaa.Affine`` interpolates and pads:

    * ``'img'`` (required): all augmenters; Affine order=1, cval=0.
    * ``'mask'`` / ``'masks'`` (optional): only Affine, PerspectiveTransform,
      Sequential and Sometimes; Affine order=0, cval=0.
    * ``'lbl'`` / ``'lbl_suc'`` / ``'lbls'`` (optional): same whitelist;
      Affine order=0, cval=-1.

    ``'masks'`` and ``'lbls'`` are iterated element by element and stored
    back as ``np.asarray`` of the augmented elements.

    This is a generator: nothing (including the ``fit_output`` capability
    probe and its warning) runs until the first item is requested. Each
    input dict is modified in place and the same object is yielded.

    Args:
        object_data: Iterable of dicts with the keys described above.
        random_state: ``np.random.RandomState`` to seed the pipeline; a new
            one is created when None. A copy is handed to imgaug.
        fit_output: Request ``fit_output=True`` on the default Affine. It is
            disabled, with a warning, if the installed imgaug rejects the
            keyword.
        aug_color: Include the default color/blur augmenters (``iaa.Noop``
            placeholders otherwise). Ignored when ``augmentations`` is given.
        aug_geo: Include the default Affine (``iaa.Noop`` otherwise).
            Ignored when ``augmentations`` is given.
        augmentations: Optional list of augmenters replacing the defaults.
        random_order: Forwarded to ``iaa.Sequential``.
        scale: ``scale`` argument of the default Affine.

    Yields:
        Each dict from ``object_data`` after in-place augmentation.
    """
    if fit_output:
        # Only probe for the keyword when the caller actually wants it, so
        # that fit_output=False never triggers an irrelevant warning.
        try:
            iaa.Affine(fit_output=True)
        except TypeError:
            warnings.warn(
                'Your imgaug does not support fit_output kwarg for '
                'imgaug.augmenters.Affine. Please install via'
                '\n\n\tpip install -e git+https://github.com/wkentaro/imgaug@affine_resize\n\n'  # NOQA
                'to enable it.')
            fit_output = False

    if random_state is None:
        random_state = np.random.RandomState()
    if augmentations is None:
        st = lambda x: iaa.Sometimes(0.3, x)  # NOQA
        kwargs_affine = dict(
            order=1,  # overridden per input kind by the hooks below
            cval=0,
            scale=scale,
            translate_px=(-16, 16),
            rotate=(-180, 180),
            shear=(-16, 16),
            mode='constant',
        )
        if fit_output:
            kwargs_affine['fit_output'] = fit_output
        augmentations = [
            st(
                iaa.WithChannels([0, 1], iaa.Multiply([1, 1.5])
                                 ) if aug_color else iaa.Noop()),
            st(
                iaa.WithColorspace('HSV',
                                   children=iaa.
                                   WithChannels([1, 2], iaa.Multiply([0.5, 2]))
                                   ) if aug_color else iaa.Noop()),
            st(
                iaa.GaussianBlur(
                    sigma=[0.0, 1.0]) if aug_color else iaa.Noop()),
            iaa.Sometimes(
                0.8,
                iaa.Affine(**kwargs_affine) if aug_geo else iaa.Noop()),
        ]
    aug = iaa.Sequential(
        augmentations,
        random_order=random_order,
        random_state=ia.copy_random_state(random_state),
    )

    # Augmenter types that may still run on masks and labels.
    geometric_types = (iaa.Affine, iaa.PerspectiveTransform, iaa.Sequential,
                       iaa.Sometimes)

    def _make_activator(order, cval, allowed=None):
        """Build a hook activator that configures Affine for one input kind.

        With ``allowed`` set, augmenters of any other type are deactivated.
        """

        def _activate(images, augmenter, parents, default):
            if allowed is not None and not isinstance(augmenter, allowed):
                return False
            if isinstance(augmenter, iaa.Affine):
                augmenter.order = Deterministic(order)
                augmenter.cval = Deterministic(cval)
            return True

        return _activate

    hooks_imgs = ia.HooksImages(activator=_make_activator(1, 0))
    hooks_masks = ia.HooksImages(
        activator=_make_activator(0, 0, geometric_types))
    hooks_lbls = ia.HooksImages(
        activator=_make_activator(0, -1, geometric_types))

    # Optional keys holding a single array, and keys holding a sequence of
    # arrays, each paired with the hooks to use for them.
    single_keys = (
        ('mask', hooks_masks),
        ('lbl', hooks_lbls),
        ('lbl_suc', hooks_lbls),
    )
    stacked_keys = (
        ('masks', hooks_masks),
        ('lbls', hooks_lbls),
    )

    for objd in object_data:
        # New random draw per dict, shared by every entry of that dict.
        aug = aug.to_deterministic()
        objd['img'] = aug.augment_image(objd['img'], hooks=hooks_imgs)
        for key, hooks in single_keys:
            if key in objd:
                objd[key] = aug.augment_image(objd[key], hooks=hooks)
        for key, hooks in stacked_keys:
            if key in objd:
                objd[key] = np.asarray([
                    aug.augment_image(item, hooks=hooks)
                    for item in objd[key]
                ])
        yield objd
def get_hooks_binmasks():
    """Return imgaug image hooks that use ``activator_binmasks`` as activator."""
    binmask_hooks = imgaug.HooksImages(activator=activator_binmasks)
    return binmask_hooks
def train():
    """Train the network with augmented batches, evaluating periodically.

    Runs 500 epochs over ``Dataset`` batches of 100. Each batch is augmented
    with one deterministic draw of the imgaug pipeline so that inputs and
    targets receive the same crop/flip/translation, while blur, dropout and
    noise are skipped for the targets. Inputs are scaled by 1/255 before
    being fed.

    Every 100 batches, and on the very last batch, the full test set is
    evaluated, the merged summaries and a result plot are written to
    ``logs/<description>-<timestamp>``, and a checkpoint is saved under
    ``save/<description>/<timestamp>`` if the current test accuracy is the
    best seen so far (ties included).
    """
    BATCH_SIZE = 100

    network = Network()

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")

    # create directory for saving models
    os.makedirs(os.path.join('save', network.description, timestamp))

    dataset = Dataset(folder='data{}_{}'.format(network.IMAGE_HEIGHT,
                                                network.IMAGE_WIDTH),
                      include_hair=False,
                      batch_size=BATCH_SIZE)

    inputs, targets = dataset.next_batch()
    print(inputs.shape, targets.shape)

    image_shape = (network.IMAGE_HEIGHT, network.IMAGE_WIDTH, 1)

    # Horizontal shifts are bounded by a third of the image *width* in both
    # directions (the lower bound previously used the height).
    max_shift_x = network.IMAGE_WIDTH // 3

    augmentation_seq = iaa.Sequential([
        # crop images from each side by 0 to 16px (randomly chosen)
        iaa.Crop(px=(0, 16), name="Cropper"),
        iaa.Fliplr(0.5, name="Flipper"),
        iaa.GaussianBlur((0, 3.0), name="GaussianBlur"),
        iaa.Dropout(0.02, name="Dropout"),
        iaa.AdditiveGaussianNoise(scale=0.01 * 255, name="GaussianNoise"),
        iaa.Affine(translate_px={"x": (-max_shift_x, max_shift_x)},
                   name="Affine"),
    ])

    # change the activated augmenters for binary masks,
    # we only want to execute horizontal crop, flip and affine transformation
    def activator_binmasks(images, augmenter, parents, default):
        if augmenter.name in ["GaussianBlur", "Dropout", "GaussianNoise"]:
            return False
        # default value for all other augmenters
        return default

    hooks_binmasks = imgaug.HooksImages(activator=activator_binmasks)

    # Build the plot-summary ops once and feed the PNG bytes through a
    # placeholder; creating them inside the loop would add new nodes to the
    # graph on every evaluation. collections=[] keeps this summary out of
    # the default summary collection, so it only runs when requested here.
    plot_png = tf.placeholder(tf.string, shape=[], name="plot_png")
    plot_image = tf.expand_dims(tf.image.decode_png(plot_png, channels=4), 0)
    plot_summary_op = tf.summary.image("plot", plot_image, collections=[])

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter('{}/{}-{}'.format(
            'logs', network.description, timestamp),
                                               graph=tf.get_default_graph())
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        test_accuracies = []
        # Fit all training data
        n_epochs = 500
        global_start = time.time()
        try:
            for epoch_i in range(n_epochs):
                dataset.reset_batch_pointer()

                for batch_i in range(dataset.num_batches_in_epoch()):
                    batch_num = (epoch_i * dataset.num_batches_in_epoch() +
                                 batch_i + 1)
                    total_batches = n_epochs * dataset.num_batches_in_epoch()

                    # One random draw shared by the inputs and the targets.
                    augmentation_seq_deterministic = (
                        augmentation_seq.to_deterministic())

                    start = time.time()
                    batch_inputs, batch_targets = dataset.next_batch()
                    batch_inputs = np.reshape(
                        batch_inputs, (dataset.batch_size,) + image_shape)
                    batch_targets = np.reshape(
                        batch_targets, (dataset.batch_size,) + image_shape)

                    batch_inputs = (
                        augmentation_seq_deterministic.augment_images(
                            batch_inputs))
                    batch_inputs = np.multiply(batch_inputs, 1.0 / 255)

                    batch_targets = (
                        augmentation_seq_deterministic.augment_images(
                            batch_targets, hooks=hooks_binmasks))

                    cost, _ = sess.run(
                        [network.cost, network.train_op],
                        feed_dict={
                            network.inputs: batch_inputs,
                            network.targets: batch_targets,
                            network.is_training: True
                        })
                    end = time.time()
                    print('{}/{}, epoch: {}, cost: {}, batch time: {}'.format(
                        batch_num, total_batches, epoch_i, cost, end - start))

                    # Evaluate only every 100 batches and on the final batch.
                    if batch_num % 100 != 0 and batch_num != total_batches:
                        continue

                    test_inputs, test_targets = dataset.test_set

                    test_inputs = np.reshape(test_inputs,
                                             (-1,) + image_shape)
                    test_targets = np.reshape(test_targets,
                                              (-1,) + image_shape)
                    test_inputs = np.multiply(test_inputs, 1.0 / 255)

                    print(test_inputs.shape)
                    summary, test_accuracy = sess.run(
                        [network.summaries, network.accuracy],
                        feed_dict={
                            network.inputs: test_inputs,
                            network.targets: test_targets,
                            network.is_training: False
                        })

                    summary_writer.add_summary(summary, batch_num)

                    print('Step {}, test accuracy: {}'.format(
                        batch_num, test_accuracy))
                    test_accuracies.append((test_accuracy, batch_num))
                    print("Accuracies in time: ",
                          [accuracy for accuracy, _ in test_accuracies])
                    max_acc = max(test_accuracies)
                    print("Best accuracy: {} in batch {}".format(
                        max_acc[0], max_acc[1]))
                    print("Total time: {}".format(time.time() - global_start))

                    # Plot example reconstructions
                    n_examples = 12
                    test_inputs = dataset.test_inputs[:n_examples]
                    test_targets = dataset.test_targets[:n_examples]
                    test_inputs = np.multiply(test_inputs, 1.0 / 255)

                    test_segmentation = sess.run(
                        network.segmentation_result,
                        feed_dict={
                            network.inputs:
                            np.reshape(test_inputs,
                                       (n_examples,) + image_shape)
                        })

                    # Prepare the plot
                    test_plot_buf = draw_results(test_inputs, test_targets,
                                                 test_segmentation,
                                                 test_accuracy, network,
                                                 batch_num)

                    # Decode the PNG buffer into an image summary and log it
                    # at the current step.
                    image_summary = sess.run(
                        plot_summary_op,
                        feed_dict={plot_png: test_plot_buf.getvalue()})
                    summary_writer.add_summary(image_summary, batch_num)

                    if test_accuracy >= max_acc[0]:
                        checkpoint_path = os.path.join('save',
                                                       network.description,
                                                       timestamp, 'model.ckpt')
                        saver.save(sess,
                                   checkpoint_path,
                                   global_step=batch_num)
        finally:
            # Flush pending events even if training is interrupted.
            summary_writer.close()