Example #1
 def __init__(self, inference_utils, config, visual_logging, gpu):
     self.visual_logging = visual_logging
     self.gpu = gpu
     self.inference_utils = inference_utils
     self.cuda_utils = CudaUtils()
     self.image_utils = ImageUtils()
     self.config = config
Example #2
 def __init__(self, data_path):
     self.data_path = data_path
     if not os.path.isdir(data_path):
         error_message = f"Data path '{data_path}' not found."
         #TODO: check all subfolders are in place
         raise Exception(error_message)
     self.models_path = os.path.join(self.data_path, "models/")
     self.objects_path = os.path.join(self.data_path, "dataset/augmentation/")
     self.dataset_path = os.path.join(self.data_path, "dataset/")
     self.best_model_file = "model.backup"
     self.logs_path = os.path.join(self.data_path, "logs")
     self.image_download_path = os.path.join(self.data_path, "image_download")
     self.system_utils = SystemUtils()
     self.image_utils = ImageUtils()
Example #3
 def extractConnectedComponents(self, classifier_predictions, masks):
     connected_components = {}
     image_utils = ImageUtils()
     for batch_index in range(masks.size()[0]):
         for class_index in range(masks.size()[1]):
             probability = classifier_predictions[batch_index][class_index].data[0]
             if probability > self.threshold:
                 mask = image_utils.toNumpy(masks[batch_index][class_index].data)
                 if self.visual_logging:
                     cv2.imshow(f'Mask {self.config.classes[class_index]}', mask)
                     cv2.waitKey(0)
                 mask = ((mask * 255) / mask.max()).astype(np.uint8)
                 self.extractConnectedComponentsInMask(batch_index, class_index, connected_components, mask)                  
     return connected_components
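For reference, a minimal shape sketch of the tensors this loop consumes, assuming PyTorch-style tensors; the dummy sizes and the 0.5 threshold below are hypothetical, not values from the original code:

import torch

# classifier_predictions: [batch, num_classes] per-class probabilities
# masks: [batch, num_classes, height, width] predicted masks
classifier_predictions = torch.rand(1, 2)
masks = torch.rand(1, 2, 32, 32)
threshold = 0.5

for batch_index in range(masks.size()[0]):
    for class_index in range(masks.size()[1]):
        if classifier_predictions[batch_index][class_index] > threshold:
            mask = masks[batch_index][class_index].numpy()
            # connected components would be extracted from `mask` here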
Example #4
    def testContrastBrightnessDistortion(self):
        image_utils = ImageUtils()
        image_distortions = ImageDistortions()
        image_base = cv2.imread('./test/data/images/square/square_1.png')
        expected_brightened_image = cv2.imread(
            './test/data/images/square/distortions/square_1_brightness.png',
            cv2.IMREAD_UNCHANGED)
        image_base_with_alpha_channel = image_utils.addAlphaChannelToImage(
            image_base)
        brightened_image = image_distortions.distortImage(
            image_base_with_alpha_channel)
        difference_of_images = cv2.subtract(brightened_image,
                                            expected_brightened_image)

        self.assertTrue(np.count_nonzero(difference_of_images > 10) < 10)
Example #5
 def __init__(self, data_path, visual_logging, reset_model, num_epochs, batch_size, learning_rate, momentum, gpu,\
     test_samples):
     super(Trainer, self).__init__()
     self.config = DatasetConfiguration(True, data_path)
     self.data_path = data_path
     self.visual_logging = visual_logging
     self.reset_model = reset_model
     self.num_epochs = num_epochs
     self.batch_size = batch_size
     self.learning_rate = learning_rate
     self.test_loss_history = []
     self.gpu = gpu
     self.momentum = momentum
     self.test_samples = test_samples
     self.system_utils = SystemUtils()
     self.logger = self.system_utils.getLogger(self)
     self.image_utils = ImageUtils()
     self.environment = EnvironmentUtils(self.data_path)
     self.cuda_utils = CudaUtils()
 def extractConnectedComponents(self, class_index, mask):
     if mask.sum() > 10.00:
         mask = (mask * 255).astype(np.uint8)
         # OpenCV 3.x API: findContours returns (image, contours, hierarchy)
         _, contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                                   cv2.CHAIN_APPROX_SIMPLE)
         x, y, w, h = cv2.boundingRect(contours[0])
         return torch.IntTensor([class_index, x, y, x + w, y + h])
Example #7
    def testBasicBackgroundRemoval(self):
        image_utils = ImageUtils()
        image_base = cv2.imread(
            'test/data/images/generated/augmented_image_with_overlappings/1_mask_square.png'
        )
        basic_background_remover = BasicBackgroundRemover()
        image_base_with_alpha_channel = image_utils.addAlphaChannelToImage(
            image_base)
        image_without_background = basic_background_remover.removeFlatBackgroundFromRGB(
            image_base)

        difference_of_images = cv2.subtract(image_without_background,
                                            image_base_with_alpha_channel)
        difference_of_images_inverse = cv2.subtract(
            image_base_with_alpha_channel, image_without_background)

        self.assertTrue(
            np.count_nonzero(
                difference_of_images - difference_of_images_inverse > 10) < 10)
 def __init__(self,
              is_train,
              data_path,
              visual_logging=False,
              max_samples=None,
              seed=None):
     super(DataAugmentationDataset, self).__init__()
     random.seed(seed)
     self.environment = EnvironmentUtils(data_path)
     self.system_utils = SystemUtils()
     self.visual_logging = visual_logging
     self.config = DatasetConfiguration(is_train, data_path)
     self.image_utils = ImageUtils()
     self.image_distortions = ImageDistortions()
     self.basic_background_remover = BasicBackgroundRemover()
     self.logger = self.system_utils.getLogger(self)
     self.maximum_area = constants.input_width * constants.input_height
     if max_samples is None:
         self.length = self.config.length
     else:
         self.length = min(self.config.length, max_samples)
Example #9
 def __init__(self, data_path, path, inferencer, result_folder,
              visual_logging, gpu):
     self.result_folder = result_folder
     self.visual_logging = visual_logging
     self.inferencer = inferencer
     self.path = path
     self.is_folder = os.path.isdir(self.path)
     self.environment = EnvironmentUtils(data_path)
     self.cuda_utils = CudaUtils()
     self.gpu = gpu
     self.model = self.cuda_utils.cudify(
         [self.environment.loadModel(self.environment.best_model_file)],
         self.gpu)[0]
     self.system_utils = SystemUtils()
     self.image_utils = ImageUtils()
Example #10
 def cropRefinerDataset(self, connected_components_predicted, predicted_masks, input_image):
     connected_components = {}
     for batch_index in connected_components_predicted.keys():
         for class_index in connected_components_predicted[batch_index].keys():
             for predicted_connected_component in connected_components_predicted[batch_index][class_index]:
                 bounding_box_object = predicted_connected_component['bounding_box']
                 resized_bounding_box = bounding_box_object.resize(
                     predicted_masks.size()[3], predicted_masks.size()[2],
                     input_image.size()[3], input_image.size()[2])
                 if bounding_box_object.area > 1:
                     if batch_index not in connected_components:
                         connected_components[batch_index] = {}
                     if class_index not in connected_components[batch_index]:
                         connected_components[batch_index][class_index] = []
                     connected_component = {
                         'predicted_mask': predicted_masks[:, class_index:class_index + 1, :, :],
                         'bounding_box': resized_bounding_box
                     }
                     connected_components[batch_index][class_index].append(connected_component)
     return connected_components
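For clarity, a sketch of the nested structure cropRefinerDataset consumes, as implied by the loops above; the concrete index values are hypothetical:

# batch index -> class index -> list of connected components
connected_components_predicted = {
    0: {
        1: [
            {
                'mask': None,            # numpy mask of one connected component
                'bounding_box': None,    # bounding box exposing .resize(...) and .area
            },
        ],
    },
}

The output keeps the same batch/class nesting, with each entry carrying a 'predicted_mask' tensor slice and the resized bounding box.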
Example #11
    def testBasicBackgroundRemovalForImageWithHole(self):
        image_utils = ImageUtils()
        image_base = cv2.imread(
            './test/data/images/square/square_5_with_hole_and_background.png')
        image_result = cv2.imread(
            'test/data/images/generated/background_removed/square_5_background_removed.png',
            cv2.IMREAD_UNCHANGED)
        basic_background_remover = BasicBackgroundRemover()
        image_without_background = basic_background_remover.removeFlatBackgroundFromRGB(
            image_base)

        difference_of_images = cv2.subtract(image_without_background,
                                            image_result)
        difference_of_images_inverse = cv2.subtract(image_result,
                                                    image_without_background)
        self.assertTrue(
            np.count_nonzero((difference_of_images +
                              difference_of_images_inverse) > 10) < 10)
Example #12
class ImageDistortions():
    def __init__(self):
        self.image_utils = ImageUtils()

    def _getScaleValue(self):
        probability = random.uniform(0, 1)
        # bands: 70% mild, 25% medium, 5% strong downscale
        if probability < 0.7:
            scale = random.uniform(0.8, 1.0)
        elif probability < 0.95:
            scale = random.uniform(0.5, 0.8)
        else:
            scale = random.uniform(0.1, 0.5)
        return scale

    def getScaleParams(self, original_width, original_height):
        proportion_y = original_width / original_height
        max_scale_x = constants.input_width / original_width
        max_scale_y = constants.input_height / original_height
        scale_x = self._getScaleValue()
        scale_y = self._getScaleValue()
        if scale_x > max_scale_x:
            scale_x = scale_x * max_scale_x * proportion_y
            scale_y = scale_y * max_scale_y
        elif scale_y > max_scale_y:
            scale_y = scale_y * max_scale_y / proportion_y
            scale_x = scale_x * max_scale_x
        return scale_x, scale_y

    def getScaleMatrix(self, scale_x, scale_y):
        scale_matrix = [
            [scale_x,   0,       0],
            [0,         scale_y, 0],
            [0,         0,       1]
        ]
        return np.array(scale_matrix)

    def getScaledRotoTranslationMatrix(self, scale_x, scale_y, original_width, original_height):
        angle_probability = random.uniform(0, 1)
        if angle_probability < 0.80:
            angle = random.uniform(-15, 15)
        elif 0.80 <= angle_probability < 0.90:
            angle = random.uniform(-30, 30)
        elif 0.90 <= angle_probability < 0.98:
            angle = random.uniform(-90, 90)
        else:
            angle = random.uniform(-180, 180)
        angle_radians = math.radians(angle)
        cos_angle = math.cos(angle_radians)
        sin_angle = math.sin(angle_radians)
        image_width = original_height * math.fabs(sin_angle)  + original_width * math.fabs(cos_angle)
        image_height = original_height * math.fabs(cos_angle) + original_width * math.fabs(sin_angle)

        # center the rotated image on the top-left corner of the image
        percentage_out_of_image = 0.2
        base_x_translation = ((1 - cos_angle) * original_width / 2) \
            - (sin_angle * original_height / 2) \
            + (image_width / 2 - original_width / 2)
        base_y_translation = (sin_angle * original_width / 2) \
            + ((1 - cos_angle) * original_height / 2) \
            + (image_height / 2 - original_height / 2)
        random_x_translation = random.uniform(
            -percentage_out_of_image * image_width,
            constants.input_width / scale_x - image_width * (1 - percentage_out_of_image))
        random_y_translation = random.uniform(
            -percentage_out_of_image * image_height,
            constants.input_height / scale_y - image_height * (1 - percentage_out_of_image))
        rotation_matrix = np.array([
            [cos_angle,  sin_angle],
            [-sin_angle, cos_angle]
        ])
        rototranslation_matrix = np.zeros((3, 3))
        rototranslation_matrix[0:2, 0:2] = rotation_matrix
        rototranslation_matrix[:, 2:3] = [
            [base_x_translation + random_x_translation],
            [base_y_translation + random_y_translation],
            [1]
        ]
        return rototranslation_matrix

    def getPerspectiveMatrix(self):
        perspective_probability = random.uniform(0, 1)
        if perspective_probability < 0.7:
            perspective_x = random.uniform(-0.00001, 0.00001)
            perspective_y = random.uniform(-0.00001, 0.00001)
        elif 0.7 <= perspective_probability < 0.95:
            perspective_x = random.uniform(-0.0001, 0.0002)
            perspective_y = random.uniform(-0.0001, 0.0002)
        else:
            perspective_x = random.uniform(-0.0003, 0.0005)
            perspective_y = random.uniform(-0.0003, 0.0005)

        perspective_matrix = [
            [1,             0,             0],
            [0,             1,             0],
            [perspective_x, perspective_y, 1]
        ]
        return np.array(perspective_matrix)

    def applyContrastAndBrightness(self, image):
        channels = ImageInfo(image).channels
        distort = bool(random.getrandbits(1))
        if distort:
            contrast_parameter = random.uniform(0.1, 2.0)
            image = cv2.merge([
                cv2.multiply(image[:, :, channel_index], contrast_parameter)
                for channel_index in range(channels)
            ])
        distort = bool(random.getrandbits(1))
        if distort:
            brightness = random.uniform(-int(np.mean(image) / 2.0),
                                        int(np.mean(image) / 2.0))
            image = cv2.merge([
                cv2.add(image[:, :, channel_index], brightness)
                for channel_index in range(channels)
            ])
        return image

    def changeContrastAndBrightnessToImage(self, image):
        transformed_image_with_color_noise = self.applyContrastAndBrightness(image[:, :, 0:3])
        return transformed_image_with_color_noise

    def distortImage(self, item_image):
        item_image_info = ImageInfo(item_image)
        (scale_x, scale_y) = self.getScaleParams(item_image_info.width, item_image_info.height)
        scale_matrix = self.getScaleMatrix(scale_x, scale_y)
        rototranslation_matrix = self.getScaledRotoTranslationMatrix(scale_x, scale_y, item_image_info.width, item_image_info.height)
        perspective_matrix = self.getPerspectiveMatrix()
        homography_matrix = np.dot(scale_matrix, np.dot(rototranslation_matrix, perspective_matrix))

        transformed_image = cv2.warpPerspective(
            item_image, homography_matrix,
            (constants.input_width, constants.input_height))
        if item_image_info.channels == 4:
            alpha_channel = transformed_image[:, :, 3]
            return self.image_utils.addAlphaChannelToImage(transformed_image, alpha_channel)
        else:
            return transformed_image
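A minimal usage sketch for ImageDistortions, assuming a BGRA object image read with cv2.IMREAD_UNCHANGED; the file paths are hypothetical:

import cv2

image_distortions = ImageDistortions()
object_image = cv2.imread('object.png', cv2.IMREAD_UNCHANGED)  # hypothetical input
distorted_object = image_distortions.distortImage(object_image)
# the result is constants.input_width x constants.input_height, alpha preserved
cv2.imwrite('object_distorted.png', distorted_object)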
class DataAugmentationDataset(Dataset):
    def __init__(self,
                 is_train,
                 data_path,
                 visual_logging=False,
                 max_samples=None,
                 seed=None):
        super(DataAugmentationDataset, self).__init__()
        random.seed(seed)
        self.environment = EnvironmentUtils(data_path)
        self.system_utils = SystemUtils()
        self.visual_logging = visual_logging
        self.config = DatasetConfiguration(is_train, data_path)
        self.image_utils = ImageUtils()
        self.image_distortions = ImageDistortions()
        self.basic_background_remover = BasicBackgroundRemover()
        self.logger = self.system_utils.getLogger(self)
        self.maximum_area = constants.input_width * constants.input_height
        if max_samples is None:
            self.length = self.config.length
        else:
            self.length = min(self.config.length, max_samples)

    def __len__(self):
        return self.length

    def cutImageBackground(self, image):
        image_info = ImageInfo(image)
        scale_x = constants.input_width / image_info.width
        scale_y = constants.input_height / image_info.height
        if scale_x > scale_y:
            image = cv2.resize(
                image,
                (constants.input_width, int(image_info.height * scale_x)))
        else:
            image = cv2.resize(
                image,
                (int(image_info.width * scale_y), constants.input_height))
        image_info = ImageInfo(image)

        x_seed = random.uniform(0,
                                1) * (image_info.width - constants.input_width)

        initial_width = int(0 + x_seed)
        final_width = int(x_seed + constants.input_width)

        y_seed = random.uniform(
            0, 1) * (image_info.height - constants.input_height)
        initial_height = int(0 + y_seed)
        final_height = int(y_seed + constants.input_height)
        return image[initial_height:final_height, initial_width:final_width, :]

    def coverInputDimensions(self, image):
        image_info = ImageInfo(image)
        if image_info.width < constants.input_width:
            image = cv2.resize(
                image, (constants.input_width,
                        int(constants.input_width / image_info.aspect_ratio)))
        image_info = ImageInfo(image)
        if image_info.height < constants.input_height:
            image = cv2.resize(
                image, (int(constants.input_height * image_info.aspect_ratio),
                        constants.input_height))
        return image

    def randomBackground(self):
        use_flat_background = bool(random.getrandbits(1))
        if use_flat_background:
            background_image = np.ones(
                (constants.input_height, constants.input_width, 3),
                dtype=np.uint8) * 255
            use_white_background = bool(random.getrandbits(1))
            if not use_white_background:
                # default to a random flat background color
                for channel in range(0, 3):
                    random_channel_value = random.uniform(0, 256)
                    background_image[:, :, channel] = random_channel_value
            return background_image
        else:
            background_index = np.random.randint(
                len(self.config.objects[constants.background_label]))
            background_image = cv2.imread(
                self.config.objects[constants.background_label]
                [background_index], cv2.IMREAD_COLOR)
            if background_image is not None and len(
                    background_image.shape) == 3:
                background = self.cutImageBackground(background_image)
                background = self.coverInputDimensions(background)
                background = self.applyRandomBackgroundObjects(background)
                return background
            else:
                if self.config.remove_corrupted_files:
                    self.logger.warning(
                        f'Removing corrupted image {self.config.objects[constants.background_label][background_index]}'
                    )
                    self.system_utils.rm(self.config.objects[
                        constants.background_label][background_index])
                raise ValueError(
                    f'Could not load background image {self.config.objects[constants.background_label][background_index]}'
                )

    def imageWithinInputDimensions(self, image):
        image_info = ImageInfo(image)
        if image_info.width > constants.input_width:
            image = cv2.resize(
                image, (constants.input_width,
                        int(constants.input_width / image_info.aspect_ratio)))
        image_info = ImageInfo(image)
        if image_info.height > constants.input_height:
            image = cv2.resize(
                image, (int(constants.input_height * image_info.aspect_ratio),
                        constants.input_height))
        return image

    def objectInClass(self, index, class_index, background_classes=False):
        if not background_classes:
            class_label = self.config.classes[class_index]
        else:
            class_label = self.config.background_classes[class_index]
        object_index = index % len(self.config.objects[class_label])
        current_object = cv2.imread(
            self.config.objects[class_label][object_index],
            cv2.IMREAD_UNCHANGED)

        if current_object is not None:
            current_object = self.basic_background_remover.removeFlatBackgroundFromRGB(
                current_object)
            current_object = self.imageWithinInputDimensions(current_object)
            return current_object
        else:
            if self.config.remove_corrupted_files:
                self.logger.warning(
                    f'Removing corrupted image {self.config.objects[class_label][object_index]}'
                )
                self.system_utils.rm(
                    self.config.objects[class_label][object_index])
            raise ValueError(
                f'Could not load object of class {class_label} in index {object_index}: {self.config.objects[class_label][object_index]}'
            )

    def objectsInClass(self, index, class_index, count):
        class_indexes_and_objects = [(class_index, self.system_utils.tryToRun(lambda : self.objectInClass(index + current_object_in_class, class_index), \
            lambda result: result is not None, \
            constants.max_image_retrieval_attempts)) for current_object_in_class in range(count)]
        return list(itertools.chain(*class_indexes_and_objects))

    def subtractSubMaskFromMainMask(self, all_masks, sub_mask, object_index):
        all_masks[:, :, object_index:object_index + 1] = cv2.subtract(
            all_masks[:, :, object_index:object_index + 1],
            sub_mask[:, :]).reshape(sub_mask.shape)

    def addSubMaskToMainMask(self, all_masks, sub_mask, object_index):
        all_masks[:, :, object_index:object_index + 1] = cv2.add(
            all_masks[:, :, object_index:object_index + 1],
            sub_mask[:, :]).reshape(sub_mask.shape)

    def getRandomClassIndexToCount(self, random_number_of_objects):
        class_index_to_count = [0] * len(self.config.classes)
        if random.random() > self.config.probability_no_objects:
            if random_number_of_objects is None:
                random_number_of_objects = random.randint(
                    1,
                    min(self.config.max_classes_per_image,
                        len(self.config.classes)) *
                    self.config.max_objects_per_class)
            while sum(class_index_to_count) < random_number_of_objects:
                random_class_index = random.choice(
                    list(range(0, len(self.config.classes))))
                if class_index_to_count[
                        random_class_index] < self.config.max_objects_per_class:
                    class_index_to_count[
                        random_class_index] = class_index_to_count[
                            random_class_index] + 1
        return class_index_to_count

    def applyRandomBackgroundObjects(self, background):
        if len(self.config.background_classes
               ) > 0 and self.config.max_background_objects_per_image > 0:
            size_random_background_objects = random.randint(
                1, self.config.max_background_objects_per_image)
            background_object_images = []
            for _ in range(size_random_background_objects):
                random_background_class_index = random.randint(
                    0,
                    len(self.config.background_classes) - 1)
                random_background_class = self.config.background_classes[
                    random_background_class_index]
                random_object_index = random.randint(
                    0,
                    len(self.config.objects[random_background_class]) - 1)
                background_object_image = self.objectInClass(
                    random_object_index,
                    random_background_class_index,
                    background_classes=True)
                distorted_background_object = self.image_distortions.distortImage(
                    background_object_image)
                background, _ = self.image_utils.pasteRGBAimageIntoRGBimage(
                    distorted_background_object, background, 0, 0)
            return background
        else:
            return background

    def extractConnectedComponents(self, class_index, mask):
        if mask.sum() > 10.00:
            mask = (mask * 255).astype(np.uint8)
            # OpenCV 3.x API: findContours returns (image, contours, hierarchy)
            _, contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                                      cv2.CHAIN_APPROX_SIMPLE)
            x, y, w, h = cv2.boundingRect(contours[0])
            return torch.IntTensor([class_index, x, y, x + w, y + h])

    def generateAugmentedImage(self, index, random_number_of_objects=None):
        class_index_to_count = self.getRandomClassIndexToCount(
            random_number_of_objects)
        class_indexes_and_objects = [
            self.objectsInClass(index, class_index, count)
            for class_index, count in enumerate(class_index_to_count)
            if count > 0
        ]
        random.shuffle(class_indexes_and_objects)
        input_image = self.system_utils.tryToRun(self.randomBackground, \
            lambda result: result is not None, \
            constants.max_image_retrieval_attempts)
        target_masks = self.environment.blankMasks(self.config.classes)
        original_object_areas = torch.zeros(len(self.config.classes))
        bounding_boxes = torch.zeros(
            (self.config.max_classes_per_image *
             self.config.max_objects_per_class, 5)).int()
        classes_in_input = set()
        for object_index, (
                class_index,
                class_object) in enumerate(class_indexes_and_objects):
            distorted_class_object = self.image_distortions.distortImage(
                class_object)
            bounding_box = self.extractConnectedComponents(
                class_index, distorted_class_object[:, :, 3:4])
            bounding_boxes[object_index:object_index + 1, :] = bounding_box
            original_object_areas[class_index] = original_object_areas[
                class_index] + distorted_class_object[:, :, 3].sum()
            input_image, object_mask = self.image_utils.pasteRGBAimageIntoRGBimage(
                distorted_class_object, input_image, 0, 0)
            self.addSubMaskToMainMask(target_masks, object_mask, class_index)
            classes_in_input.add(class_index)
        if self.visual_logging:
            cv2.imshow('Before Distortion', input_image)
        input_image = self.image_distortions.changeContrastAndBrightnessToImage(
            input_image)
        if self.visual_logging:
            cv2.imshow('After Distortion', input_image)
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        self.environment.storeSampleWithIndex(index, self.config.is_train,
                                              input_image, target_masks,
                                              original_object_areas,
                                              bounding_boxes, classes_in_input,
                                              self.config.classes)
        return input_image, target_masks, bounding_boxes

    def isDataSampleConsistentWithDatasetConfiguration(self, input_image,
                                                       target_masks,
                                                       bounding_boxes):
        if input_image is not None and target_masks is not None and bounding_boxes is not None:
            if target_masks.shape[2] == len(self.config.classes):
                is_consistent = True
                for bounding_box_index in range(bounding_boxes.size()[0]):
                    is_consistent = is_consistent and bounding_boxes[
                        bounding_box_index][0] < len(self.config.classes)
                return is_consistent
        return False

    def __getitem__(self, index, max_attempts=10):
        (input_image, target_masks, bounding_boxes) = None, None, None
        if np.random.uniform(0, 1,
                             1)[0] <= self.config.probability_using_cache:
            try:
                input_image, target_masks, original_object_areas, bounding_boxes = self.environment.getSampleWithIndex(
                    index, self.config.is_train, self.config.classes)
            except BaseException:
                sys.stderr.write(traceback.format_exc())
        if not self.isDataSampleConsistentWithDatasetConfiguration(
                input_image, target_masks, bounding_boxes):
            index_path = self.environment.indexPath(index,
                                                    self.config.is_train)
            self.system_utils.rm(index_path)
            current_attempt = 0
            while (not self.isDataSampleConsistentWithDatasetConfiguration(
                    input_image, target_masks,
                    bounding_boxes)) and current_attempt < max_attempts:
                try:
                    input_image, target_masks, bounding_boxes = self.generateAugmentedImage(
                        index)
                except BaseException:
                    sys.stderr.write(traceback.format_exc())
                    current_attempt = current_attempt + 1
                    index = index + 1
            if input_image is None or target_masks is None:
                raise ValueError(
                    'A fatal problem occurred while loading images during data sampling. Please check the error messages above.'
                )
        input_image = transforms.ToTensor()(input_image)
        target_masks = transforms.ToTensor()(target_masks)
        return input_image, target_masks, bounding_boxes
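A minimal usage sketch for DataAugmentationDataset, mirroring how Trainer wraps it in a DataLoader; the data path and batch size are hypothetical:

import torch

dataset_train = DataAugmentationDataset(True, '/path/to/data', visual_logging=False)
train_loader = torch.utils.data.DataLoader(
    dataset=dataset_train, batch_size=4, shuffle=False, num_workers=0)
for input_images, target_masks, bounding_boxes in train_loader:
    # input_images: [batch, 3, input_height, input_width]
    # target_masks: [batch, num_classes, input_height, input_width]
    # bounding_boxes: [batch, boxes, 5] rows of (class, x_min, y_min, x_max, y_max)
    break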
class BasicBackgroundRemover():
    def __init__(self):
        self.image_utils = ImageUtils()
        self.object_area_threshold = 0.05

    def findContours(self, filtered_image):
        _, contours, hierarchy = cv2.findContours(filtered_image,
                                                  cv2.RETR_TREE,
                                                  cv2.CHAIN_APPROX_NONE)
        areas = [cv2.contourArea(c) for c in contours]
        if len(areas) > 0:
            biggest_area_index = np.argmax(areas)
            areas_and_indexes = [
                (area, index) for index, area in enumerate(areas) if area /
                areas[biggest_area_index] > self.object_area_threshold
            ]
            sorted_areas_and_indexes = sorted(
                areas_and_indexes,
                key=lambda area_index_tuple: area_index_tuple[0],
                reverse=True)
            area_and_contours = [(area, contours[index])
                                 for (area, index) in sorted_areas_and_indexes]

            return [(area, cv2.approxPolyDP(c, 3, True))
                    for area, c in area_and_contours]
        else:
            return []

    def applySobelFilter(self, image):
        def sobel(level):
            sobel_horizontal = cv2.Sobel(level, cv2.CV_16S, 1, 0, ksize=3)
            sobel_vertical = cv2.Sobel(level, cv2.CV_16S, 0, 1, ksize=3)
            sobel_response = np.hypot(sobel_horizontal, sobel_vertical)
            sobel_response[sobel_response > 255] = 255
            return sobel_response

        if len(image.shape) == 2:
            sobel_image = sobel(image)
        else:
            sobel_image = np.max(np.array([
                sobel(image[:, :, 0]),
                sobel(image[:, :, 1]),
                sobel(image[:, :, 2])
            ]),
                                 axis=0)
        mean = np.mean(sobel_image)
        sobel_image[sobel_image <= mean] = 0
        sobel_image = sobel_image.astype(np.uint8)
        return sobel_image

    def detectAndRemoveBackgroundColor(self, image):
        image_info = ImageInfo(image)
        color_borders = [
            image[0][0], image[image_info.height - 1][0],
            image[0][image_info.width - 1],
            image[image_info.height - 1][image_info.width - 1]
        ]
        result = 255
        for color_border in color_borders:
            result = cv2.subtract(
                result,
                cv2.inRange(image, cv2.subtract(color_border, 2),
                            cv2.add(color_border, 1)))
        result = cv2.GaussianBlur(result, (5, 5), 3)
        result = cv2.erode(result, None, iterations=1)
        result = cv2.morphologyEx(result, cv2.MORPH_OPEN, None, iterations=3)
        areas_and_contours = self.findContours(result)
        cv2.drawContours(result, [areas_and_contours[0][1]], -1, (255), -1)
        for i in range(1, len(areas_and_contours)):
            (area, contour) = areas_and_contours[i]
            if area / areas_and_contours[0][0] > self.object_area_threshold:
                color = 0
            else:
                color = 255
            cv2.drawContours(result, [areas_and_contours[i][1]], -1, color, -1)
        return result

    def removeBackgroundInsideMainObject(self, original_image, contours, mask):
        original_image_info = ImageInfo(original_image)
        biggest_area, _ = contours[0]
        background_subtracted_mask = self.detectAndRemoveBackgroundColor(
            original_image)
        for contour_index in range(1, len(contours)):
            (area, contour) = contours[contour_index]
            first_pixel_in_contour = contour[0][0]
            first_pixel_in_contour = (first_pixel_in_contour[1],
                                      first_pixel_in_contour[0])
            # if the contour is at least 50% of the biggest element's size, it is probably another object, provided it is not overlapping
            if area / biggest_area > 0.5 and not background_subtracted_mask[
                    first_pixel_in_contour] == 0:
                cv2.fillPoly(mask, [contour], 255)
            else:
                '''
                Check whether the contour belongs to the background: if
                background_subtracted_mask has zero-valued pixels inside the
                contour, then the contour is part of the background.
                '''
                if contour_index == 1:
                    background_subtracted_mask[background_subtracted_mask ==
                                               1] = 3
                    background_subtracted_mask[background_subtracted_mask ==
                                               2] = 3
                contour_canvas = self.image_utils.blankImage(
                    original_image_info.width, original_image_info.height)
                cv2.fillPoly(contour_canvas, [contour], 1)
                overlapping_areas = cv2.add(background_subtracted_mask,
                                            contour_canvas)
                mask[overlapping_areas == 1] = 0

    def removeFlatBackgroundFromRGB(self, image, full_computation=True):
        if len(image.shape) < 3:
            raise ValueError('Color channel not found in image')
        if image.shape[2] == 4:
            '''
            If the alpha channel contains at least one pixel that is not fully
            opaque, then the alpha channel is truly an alpha channel.
            '''
            if np.any(image[:, :, 3] != 255):
                return image
            else:
                (b, g, r, _) = cv2.split(image)
                image = cv2.merge([b, g, r])
        image_info = ImageInfo(image)
        blurred_image = cv2.GaussianBlur(image, (5, 5), 0)
        sobel_image = self.applySobelFilter(blurred_image)
        contours = self.findContours(sobel_image)
        mask = self.image_utils.blankImage(image_info.width, image_info.height)
        if len(contours) == 0:
            contours = [
                (0,
                 np.array([[0, 0], [0, image_info.height - 1],
                           [image_info.width - 1, 0],
                           [image_info.width - 1, image_info.height - 1]]))
            ]
        _, biggest_contour = contours[0]
        cv2.fillPoly(mask, [biggest_contour], 255)
        if full_computation and len(contours) > 1:
            self.removeBackgroundInsideMainObject(blurred_image, contours,
                                                  mask)
        mask = cv2.erode(mask, None, iterations=2)
        b, g, r = cv2.split(image)
        rgba = [b, g, r, mask]
        rgba = cv2.merge(rgba)  # cv2.merge expects only the list of channels
        return rgba
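A minimal usage sketch for BasicBackgroundRemover; the file paths are hypothetical:

import cv2

basic_background_remover = BasicBackgroundRemover()
image = cv2.imread('object_photo.png')  # hypothetical flat-background photo
image_with_alpha = basic_background_remover.removeFlatBackgroundFromRGB(image)
# the returned image is BGRA; the alpha channel encodes the detected object mask
cv2.imwrite('object_photo_alpha.png', image_with_alpha)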
 def __init__(self):
     self.image_utils = ImageUtils()
     self.object_area_threshold = 0.05
Example #16
 def __init__(self, threshold, config, visual_logging):
     self.config = config
     self.threshold = threshold
     self.visual_logging = visual_logging
     self.image_utils = ImageUtils()
Example #17
class Trainer():

    def __init__(self, data_path, visual_logging, reset_model, num_epochs, batch_size, learning_rate, momentum, gpu,\
        test_samples):
        super(Trainer, self).__init__()
        self.config = DatasetConfiguration(True, data_path)
        self.data_path = data_path
        self.visual_logging = visual_logging
        self.reset_model = reset_model
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.test_loss_history = []
        self.gpu = gpu
        self.momentum = momentum
        self.test_samples = test_samples
        self.system_utils = SystemUtils()
        self.logger = self.system_utils.getLogger(self)
        self.image_utils = ImageUtils()
        self.environment = EnvironmentUtils(self.data_path)
        self.cuda_utils = CudaUtils()

    def getWorkers(self):
        if self.visual_logging:
            return 0
        else:
            return 6

    def testLoss(self):
        current_test_loss = None
        dataset_test = DataAugmentationDataset(False, self.data_path,
                                               self.visual_logging,
                                               self.test_samples)
        test_loader = torch.utils.data.DataLoader(
            dataset=dataset_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.getWorkers())
        total_loss_average = 0.0
        loss_mask_average = 0.0
        loss_refiner_average = 0.0
        loss_classifier_average = 0.0
        iterations = 0.0
        for i, (input_images, target_mask,
                bounding_boxes) in enumerate(test_loader):
            input_images, target_mask = self.cuda_utils.toVariable(
                self.cuda_utils.cudify([input_images, target_mask], self.gpu))
            if self.visual_logging:
                self.visualLoggingDataset(input_images, target_mask)
            total_loss, loss_mask, loss_refiner, loss_classifier = self.computeLoss(
                input_images, target_mask, bounding_boxes)
            total_loss_average = total_loss_average + total_loss[0].data[0]
            loss_mask_average = loss_mask_average + loss_mask[0].data[0]
            loss_refiner_average = loss_refiner_average + loss_refiner[0].data[
                0]
            loss_classifier_average = loss_classifier_average + loss_classifier[
                0].data[0]
            iterations = iterations + 1.0
        total_loss_average = total_loss_average / iterations
        loss_mask_average = loss_mask_average / iterations
        loss_refiner_average = loss_refiner_average / iterations
        loss_classifier_average = loss_classifier_average / iterations
        self.log(
            f'Test Loss -- Total Loss: {total_loss_average:{1}.{4}} -- Classifier Loss: {loss_classifier_average:{1}.{4}} -- Mask Loss: {loss_mask_average:{1}.{4}} -- Refined Mask Loss: {loss_refiner_average:{1}.{4}}'
        )
        self.test_loss_history.append(total_loss_average)
        return loss_refiner_average

    def log(self, text):
        self.logger.info(f"{datetime.datetime.utcnow()} -- {text}")

    def testAndSaveIfImproved(self, best_test_loss):
        average_current_test_loss = self.testLoss()
        if average_current_test_loss < best_test_loss:
            self.log(
                f"Model Improved. Previous Best Test Loss {best_test_loss:{1}.{4}} | Current Best Test Loss  {average_current_test_loss:{1}.{4}} | Improvement Change: {(100.0 * (best_test_loss - average_current_test_loss) / average_current_test_loss):{1}.{4}} %"
            )
            best_test_loss = average_current_test_loss
            self.log(f"Saving model...")
            self.environment.saveModel(self.model,
                                       self.environment.best_model_file)
            self.log(f"...model saved")
        else:
            self.log(
                f"Model did *NOT* Improve. Current Best Test Loss {best_test_loss:{1}.{4}} | Current Test Loss {average_current_test_loss:{1}.{4}} | Improvement Change: {(100.0 * (best_test_loss - average_current_test_loss) / average_current_test_loss):{1}.{4}} %"
            )
        return best_test_loss

    def loadModel(self):
        model = None
        if not self.reset_model:
            model = self.environment.loadModel(
                self.environment.best_model_file)
        else:
            model = Model(len(self.config.classes), 32, 5, 15)
        self.log(model)
        return model

    def logBatch(self, target_mask, title):
        if self.visual_logging:
            for current_index in range(target_mask.size()[0]):
                for current_class in range(len(self.config.classes)):
                    cv2.imshow(
                        f'{title} {current_index}/"{self.config.classes[current_class]}".',
                        self.image_utils.toNumpy(
                            target_mask.data[current_index][current_class]))
            cv2.waitKey(0)
            cv2.destroyAllWindows()

    def visualLoggingDataset(self, input_images, target_mask):
        if self.visual_logging:
            for current_index in range(input_images.size()[0]):
                cv2.imshow(
                    f'Input {current_index}',
                    self.image_utils.toNumpy(input_images.data[current_index]))
                for current_class_index in range(
                        target_mask[current_index].size()[0]):
                    cv2.imshow(
                        f'Target {current_index}/{self.config.classes[current_class_index]}',
                        self.image_utils.toNumpy(
                            target_mask.data[current_index]
                            [current_class_index]))
            cv2.waitKey(0)
            cv2.destroyAllWindows()

    def visualLoggingOutput(self, network_output, target_mask_scaled):
        if self.visual_logging:
            classes, object_found, mask_scaled, mask, roi_align, bounding_boxes = network_output
            for current_index in range(mask_scaled.size()[0]):
                for current_class in range(len(self.config.classes)):
                    cv2.imshow(
                        f'Target {current_index}/"{self.config.classes[current_class]}".',
                        self.image_utils.toNumpy(
                            target_mask_scaled.data[current_index]
                            [current_class]))
                    cv2.imshow(
                        f'Output {current_index}/"{self.config.classes[current_class]}".',
                        self.image_utils.toNumpy(
                            mask_scaled.data[current_index][current_class]))
            cv2.waitKey(0)
            cv2.destroyAllWindows()

    def logLoss(self, total_loss, loss_mask, loss_refiner, loss_classifier,
                epoch, train_dataset_index, dataset_train):
        self.log(
            f'Epoch [{epoch+1}/{self.num_epochs}] -- Iter [{train_dataset_index+1}/{math.ceil(len(dataset_train)/self.batch_size)}] --  Total Loss: {total_loss.data[0]:{1}.{4}} -- Classifier Loss: {loss_classifier.data[0]:{1}.{4}} -- Mask Loss: {loss_mask.data[0]:{1}.{4}} -- Refined Mask Loss: {loss_refiner.data[0]:{1}.{4}}'
        )

    def buildOptimizer(self):
        optimizer = torch.optim.SGD([{
            'params': self.model.parameters(),
            'lr': self.learning_rate
        }],
                                    momentum=self.momentum,
                                    nesterov=True)
        return optimizer

    def logRefiner(self, refiner_input_image, target_mask, predicted_mask,
                   predicted_refined_mask, class_index):
        if self.visual_logging:
            cv2.imshow('Refiner Input.',
                       self.image_utils.toNumpy(refiner_input_image.squeeze()))
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        if self.visual_logging:
            cv2.imshow(
                f'Refiner Target Mask "{self.config.classes[class_index]}".',
                self.image_utils.toNumpy(target_mask.data.squeeze()))
            cv2.imshow(
                f'Refiner Predicted Mask "{self.config.classes[class_index]}".',
                self.image_utils.toNumpy(
                    torch.nn.Upsample(
                        size=(target_mask.size()[2], target_mask.size()[3]),
                        mode='bilinear')(predicted_mask).data.squeeze()))
            #cv2.imshow(f'Refiner "{self.config.classes[class_index]}".', self.image_utils.toNumpy(input_image.data.squeeze()))
            cv2.waitKey(0)
            cv2.destroyAllWindows()
        if self.visual_logging:
            cv2.imshow(
                f'Refiner Refined Mask "{self.config.classes[class_index]}".',
                self.image_utils.toNumpy(
                    predicted_refined_mask.data.squeeze()))
            cv2.waitKey(0)
            cv2.destroyAllWindows()

    def balancedLoss(self, predictions, targets):
        foreground = Variable(targets.data, requires_grad=False)
        background = Variable((1.0 - targets).data, requires_grad=False)
        foreground_loss = nn.L1Loss(size_average=True, reduce=False)(
            foreground * predictions,
            foreground * targets).mean() / (foreground.mean() + 1e-10)
        background_loss = nn.L1Loss(size_average=True, reduce=False)(
            background * predictions,
            background * targets).mean() / (background.mean() + 1e-10)
        if self.visual_logging and len(targets.size()) == 4:
            self.logBatch(foreground, "Tar.Fore.")
            self.logBatch(
                nn.L1Loss(size_average=False,
                          reduce=False)(foreground * predictions,
                                        foreground * targets), "Fore loss")
            self.logBatch(background, "Tar.Back.")
            self.logBatch(
                nn.L1Loss(size_average=False,
                          reduce=False)(background * predictions,
                                        background * targets), "Back loss")
        return (foreground_loss + background_loss) / 2.0

    def computeLoss(self, input_images, target_mask, bounding_boxes):
        predicted_masks, mask_embeddings, embeddings_merged, embeddings_2, embeddings_4, embeddings_8 = self.model.forward(
            input_images)
        self.logBatch(predicted_masks, "Predict")
        target_mask_scaled_16 = (nn.AvgPool2d(16)(target_mask) > 0.5).float()
        self.logBatch(target_mask_scaled_16, "Target")
        loss_mask = self.balancedLoss(predicted_masks,
                                      target_mask_scaled_16) * 0.5
        classifier_predictions = self.model.classifiers(
            [predicted_masks, embeddings_merged])
        classifier_targets = (target_mask.view(target_mask.size()[0],
                                               target_mask.size()[1],
                                               -1).sum(2) > 0).float()
        loss_classifier = self.balancedLoss(classifier_predictions,
                                            classifier_targets) * 0.1
        self.logBatch(mask_embeddings[:, 0, :, :, :].squeeze(1), "Embed")
        predicted_refined_mask = self.model.mask_refiners([
            input_images.size(), predicted_masks, mask_embeddings,
            embeddings_merged, embeddings_2, embeddings_4, embeddings_8
        ])
        self.logBatch(predicted_refined_mask, "Predict")
        target_mask_scaled_2 = (nn.AvgPool2d(2)(target_mask) > 0.5).float()
        self.logBatch(target_mask_scaled_2, "Target")
        loss_refiner = self.balancedLoss(predicted_refined_mask,
                                         target_mask_scaled_2)
        total_loss = loss_mask + loss_refiner + loss_classifier
        return total_loss, loss_mask, loss_refiner, loss_classifier

    def train(self):
        self.model = self.cuda_utils.cudify([self.loadModel()], self.gpu)[0]
        best_test_loss = self.testLoss()
        self.log(f"Initial Test Loss {best_test_loss:{1}.{4}} ")
        optimizer = self.buildOptimizer()
        for epoch in range(self.num_epochs):
            self.log(f"Epoch {epoch}")
            dataset_train = DataAugmentationDataset(True, self.data_path,
                                                    self.visual_logging)
            train_loader = torch.utils.data.DataLoader(
                dataset=dataset_train,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.getWorkers())
            for train_dataset_index, (
                    input_images, target_mask,
                    bounding_boxes) in enumerate(train_loader):
                sys.stdout.flush()
                input_images, target_mask = self.cuda_utils.toVariable(
                    self.cuda_utils.cudify([input_images, target_mask],
                                           self.gpu))
                self.visualLoggingDataset(input_images, target_mask)
                optimizer.zero_grad()
                total_loss, loss_mask, loss_refiner, loss_classifier = self.computeLoss(
                    input_images, target_mask, bounding_boxes)
                total_loss.backward()
                optimizer.step()
                self.logLoss(total_loss, loss_mask, loss_refiner,
                             loss_classifier, epoch, train_dataset_index,
                             dataset_train)
                if (train_dataset_index + 1) % 1000 == 0:
                    best_test_loss = self.testAndSaveIfImproved(best_test_loss)

            self.environment.saveModel(self.model, f"{(epoch + 1)}.backup")
            best_test_loss = self.testAndSaveIfImproved(best_test_loss)
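A minimal usage sketch for Trainer; every argument value below is a hypothetical example, not a default from the original code:

trainer = Trainer(
    data_path='/path/to/data',
    visual_logging=False,
    reset_model=True,
    num_epochs=10,
    batch_size=4,
    learning_rate=0.001,
    momentum=0.9,
    gpu=0,
    test_samples=100)
trainer.train()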
Example #18
class EnvironmentUtils():

    def __init__(self, data_path):
        self.data_path = data_path
        if not os.path.isdir(data_path):
            error_message = f"Data path '{data_path}' not found."
            #TODO: check all subfolders are in place
            raise Exception(error_message)
        self.models_path = os.path.join(self.data_path, "models/")
        self.objects_path = os.path.join(self.data_path, "dataset/augmentation/")
        self.dataset_path = os.path.join(self.data_path, "dataset/")
        self.best_model_file = "model.backup"
        self.logs_path = os.path.join(self.data_path, "logs")
        self.image_download_path = os.path.join(self.data_path, "image_download")
        self.system_utils = SystemUtils()
        self.image_utils = ImageUtils()

    def objectsFolder(self, class_name, is_train):
        return os.path.join(self.objects_path, constants.datasetType(is_train), class_name)

    def classesInDatasetFolder(self, is_train):
        base_path = os.path.join(self.objects_path, constants.datasetType(is_train))
        # keep only directories so class names stay aligned with class paths
        classes = [class_directory for class_directory in os.listdir(base_path)
                   if os.path.isdir(os.path.join(base_path, class_directory))]
        class_paths = [os.path.join(base_path, class_directory)
                       for class_directory in classes]
        return list(zip(classes, class_paths))

    def loadModelFromPath(self, path):
        return torch.load(path, map_location=lambda storage, loc: storage)

    def saveModel(self, mask_detector_model, name):
        path = os.path.join(self.models_path, name)
        torch.save(mask_detector_model.float(), path)

    def inputFilenamePath(self, index_path):
        index_filename_path = os.path.join(index_path, constants.dataset_input_filename)
        return index_filename_path

    def boundingBoxesFilenamePath(self, index_path):
        index_filename_path = os.path.join(index_path, constants.bounding_boxes_filename)
        return index_filename_path

    def indexPath(self, index, is_train, clean_dir=False):
        dataset_type = constants.datasetType(is_train)
        index_path = os.path.join(self.dataset_path, dataset_type, f"{index}")
        self.system_utils.makeDirIfNotExists(index_path)
        if clean_dir:
            self.system_utils.removeFilesFromDir(index_path)
        return index_path

    def blankMasks(self, classes):
        return self.image_utils.blankImage(constants.input_width, constants.input_height, len(classes))

    def objectsInImage(self, classes):
        objects_in_image = torch.FloatTensor(len(classes))
        objects_in_image.zero_()
        return objects_in_image

    def _retrieveAlphaMasksAndObjects(self, alpha_mask_image_paths, classes, index_path):
        target_masks = self.blankMasks(classes)
        objects_in_image = self.objectsInImage(classes)
        for alpha_image_path in alpha_mask_image_paths:
            split_alpha_file_path = re.sub(f'{constants.object_ext}$', '', alpha_image_path).split(constants.dataset_mask_prefix)
            (_, class_name) = split_alpha_file_path
            if class_name in classes:
                class_index = classes.index(class_name)
                mask_class_image = cv2.imread(os.path.join(index_path, alpha_image_path), cv2.IMREAD_UNCHANGED)
                if mask_class_image is None:
                    return None
                else:
                    mask_class_image = mask_class_image.reshape(mask_class_image.shape[0], mask_class_image.shape[1], 1)
                    target_masks[:, :, class_index : class_index + 1] = mask_class_image[:, :]
            else:
                return None
        return target_masks

    def getSampleWithIndex(self, index, is_train, classes):
        index_path = self.indexPath(index, is_train)
        input_filename_path = self.inputFilenamePath(index_path)
        alpha_mask_image_paths = self.system_utils.imagesInFolder(index_path, constants.dataset_mask_prefix_regex)
        if os.path.isfile(input_filename_path):
            bounding_boxes = torch.load(self.boundingBoxesFilenamePath(index_path))
            input_image, target_masks = None, None
            input_image = cv2.imread(input_filename_path, cv2.IMREAD_COLOR)
            target_masks = self._retrieveAlphaMasksAndObjects(alpha_mask_image_paths, classes, index_path)
            if target_masks is None:
                return None, None, None, None
            original_object_areas_path = self.originalObjectAreasPath(index_path)
            original_object_areas = torch.load(original_object_areas_path)
            return input_image, target_masks, original_object_areas, bounding_boxes
        else:
            return None, None, None, None

    def originalObjectAreasPath(self, index_path):
        original_object_areas_path = os.path.join(index_path, constants.dataset_original_object_areas_filename)
        return original_object_areas_path

    def storeSampleWithIndex(self, index, is_train, input_image, target_masks, original_object_areas, bounding_boxes, mask_class_indexes, classes):
        index_path = self.indexPath(index, is_train, clean_dir=True)
        for class_index in mask_class_indexes:
            class_name = classes[class_index]
            object_mask_filename = os.path.join(index_path, f'{constants.dataset_mask_prefix}{class_name}.png')
            cv2.imwrite(object_mask_filename, target_masks[:, :, class_index : class_index + 1], [cv2.IMWRITE_PNG_COMPRESSION, 9])
        torch.save(bounding_boxes.cpu(), self.boundingBoxesFilenamePath(index_path))
        input_filename_path = self.inputFilenamePath(index_path)        
        cv2.imwrite(input_filename_path, input_image, [cv2.IMWRITE_PNG_COMPRESSION, 9])
        original_object_areas_path = self.originalObjectAreasPath(index_path)
        torch.save(original_object_areas.float(), original_object_areas_path)

    def loadModel(self, name):
        path = os.path.join(self.models_path, name)
        if os.path.isfile(path):
            return self.loadModelFromPath(path)
        else:
            sys.stderr.write(f"Model file not found in {path}.")
            return None
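A minimal usage sketch for EnvironmentUtils; the data path is hypothetical:

environment = EnvironmentUtils('/path/to/data')
model = environment.loadModel(environment.best_model_file)
if model is None:
    raise RuntimeError('No trained model found; train one first.')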
Example #19
 def __init__(self):
     self.image_utils = ImageUtils()
Example #20
class Inferencer():
    def __init__(self, inference_utils, config, visual_logging, gpu):
        self.visual_logging = visual_logging
        self.gpu = gpu
        self.inference_utils = inference_utils
        self.cuda_utils = CudaUtils()
        self.image_utils = ImageUtils()
        self.config = config

    def extractObjects(self, inference_results):
        new_inference_results = []
        for inference_result in inference_results:
            mask = self.image_utils.toNumpy(
                inference_result.mask.squeeze().data)
            original_image = self.image_utils.toNumpy(
                inference_result.image.squeeze().data)
            mask = (mask * 255).astype(np.uint8)[:, :, 0]
            blue_channel, green_channel, red_channel = cv2.split(
                original_image)
            image_cropped_with_alpha_channel = cv2.merge(((blue_channel * 255).astype(np.uint8), \
                (green_channel * 255).astype(np.uint8), (red_channel * 255).astype(np.uint8), mask))
            inference_result = InferenceResult(\
                class_label = inference_result.class_label,
                bounding_box  = inference_result.bounding_box,
                mask = inference_result.mask,
                image = image_cropped_with_alpha_channel)
            new_inference_results.append(inference_result)
        return new_inference_results

    def refine(self, refiner_dataset, mask_refiner, original_image):
        inference_results = []
        for batch_index in refiner_dataset.keys():
            for class_index in refiner_dataset[batch_index].keys():
                for connected_component in refiner_dataset[batch_index][
                        class_index]:
                    predicted_mask = connected_component['predicted_mask']
                    if self.visual_logging:
                        cv2.imshow(
                            f'Mask {self.config.classes[class_index]}',
                            self.image_utils.toNumpy(
                                predicted_mask.squeeze().data))
                        cv2.waitKey(0)
                    predicted_refined_mask = F.upsample(
                        predicted_mask,
                        size=(original_image.size()[1],
                              original_image.size()[2]),
                        mode='bilinear')
                    inference_result = InferenceResult(\
                        class_label = self.config.classes[class_index],
                        bounding_box = connected_component['bounding_box'],
                        mask = predicted_refined_mask,
                        image = original_image)
                    inference_results.append(inference_result)
        return inference_results

    def inferenceOnImage(self, model, original_input_image):
        input_image, new_height, new_width = self.image_utils.paddingScale(
            original_input_image)
        if self.visual_logging:
            cv2.imshow('Input padding scale', input_image)
            cv2.waitKey(0)
        input_image = Variable(transforms.ToTensor()(input_image))
        input_image = self.cuda_utils.cudify([input_image.unsqueeze(0)],
                                             self.gpu)[0]
        original_input_image = Variable(
            transforms.ToTensor()(original_input_image))
        predicted_masks, mask_embeddings, embeddings_merged, embeddings_2, embeddings_4, embeddings_8 = model.forward(
            input_image)
        classifier_predictions = model.classifiers(
            [predicted_masks, embeddings_merged])
        predicted_refined_masks = model.mask_refiners([
            input_image.size(), predicted_masks, mask_embeddings,
            embeddings_merged, embeddings_2, embeddings_4, embeddings_8
        ])
        if self.visual_logging:
            for class_index in range(predicted_masks.size()[1]):
                cv2.imshow(
                    f'Refined Mask {self.config.classes[class_index]}',
                    self.image_utils.toNumpy(
                        predicted_refined_masks[0, class_index, :, :].data))
            cv2.waitKey(0)
        connected_components_predicted = self.inference_utils.extractConnectedComponents(
            classifier_predictions, predicted_refined_masks)
        refiner_dataset = self.inference_utils.cropRefinerDataset(
            connected_components_predicted, predicted_refined_masks, input_image)
        inference_results = self.refine(refiner_dataset, model.mask_refiners,
                                        original_input_image)
        objects = self.extractObjects(inference_results)
        return objects
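A minimal end-to-end usage sketch for Inferencer, assuming InferenceUtils takes the (threshold, config, visual_logging) constructor shown in Example #16, and that `model` and `config` come from EnvironmentUtils and DatasetConfiguration respectively; the paths and the 0.5 threshold are hypothetical:

import cv2

inference_utils = InferenceUtils(0.5, config, False)
inferencer = Inferencer(inference_utils, config, visual_logging=False, gpu=0)
original_image = cv2.imread('photo.png', cv2.IMREAD_COLOR)  # hypothetical input
objects = inferencer.inferenceOnImage(model, original_image)
for index, inference_result in enumerate(objects):
    # each result carries a class label and a BGRA crop of the detected object
    cv2.imwrite(f'{inference_result.class_label}_{index}.png', inference_result.image)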