Example #1
import tensorflow as tf
from segmentation_models import Unet, FPN

def get_model(cfg, training=True):
    tf.keras.backend.set_learning_phase(training)

    model = None
    n_classes = len(cfg.CLASSES.keys())
    if cfg.model_type == "UNET":
        model = Unet(backbone_name=cfg.backbone_name,
                     input_shape=cfg.input_shape,
                     classes=n_classes,
                     activation='sigmoid' if n_classes == 1 else 'softmax',
                     weights=None,
                     encoder_weights=cfg.encoder_weights,
                     encoder_freeze=cfg.encoder_freeze,
                     encoder_features=cfg.encoder_features,
                     decoder_block_type=cfg.decoder_block_type,
                     decoder_filters=cfg.decoder_filters,
                     decoder_use_batchnorm=True)
    elif cfg.model_type == "FPN":
        model = FPN(backbone_name=cfg.backbone_name,
                    input_shape=cfg.input_shape,
                    classes=n_classes,
                    activation='sigmoid' if n_classes == 1 else 'softmax',
                    weights=None,
                    encoder_weights=cfg.encoder_weights,
                    encoder_freeze=cfg.encoder_freeze,
                    encoder_features=cfg.encoder_features)
    else:
        raise ValueError("Unsupported model type: {}".format(cfg.model_type))

    if cfg.pretrained_model is not None:
        model.load_weights(cfg.pretrained_model)

    return model
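
A minimal sketch of calling this helper, assuming a hypothetical config object carrying the fields read above (all names and values here are illustrative, not from the original):

from types import SimpleNamespace

cfg = SimpleNamespace(
    CLASSES={'background': 0, 'building': 1, 'road': 2},  # hypothetical label map
    model_type='UNET',
    backbone_name='resnet34',
    input_shape=(256, 256, 3),
    encoder_weights='imagenet',
    encoder_freeze=False,
    encoder_features='default',
    decoder_block_type='upsampling',
    decoder_filters=(256, 128, 64, 32, 16),
    pretrained_model=None,
)
model = get_model(cfg, training=True)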
Example #2
def save_model_to_json(modeljsonfname):
    # backbone_name, backbone_type, num_channels, num_mask_channels, act_fcn,
    # encoder_weights and get_feature_layers are module-level globals/helpers
    # defined elsewhere in the original script
    layers = get_feature_layers(backbone_name, 4)
    if backbone_type == 'FPN':
        model = FPN(input_shape=(None, None, num_channels),
                    classes=num_mask_channels,
                    encoder_weights=encoder_weights,
                    backbone_name=backbone_name,
                    activation=act_fcn,
                    encoder_features=layers)
    elif backbone_type == 'Unet':
        model = Unet(input_shape=(None, None, num_channels),
                     classes=num_mask_channels,
                     encoder_weights=encoder_weights,
                     backbone_name=backbone_name,
                     activation=act_fcn,
                     encoder_features=layers)
    else:
        raise ValueError("Unsupported backbone_type: {}".format(backbone_type))
    # model.summary()
    # serialize the model architecture to JSON
    model_json = model.to_json()
    with open(modeljsonfname, "w") as json_file:
        json_file.write(model_json)
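
To rebuild the architecture later, the saved JSON can be read back with Keras' model_from_json; a minimal sketch, assuming standard layers (depending on the segmentation_models version, custom_objects may be required):

from tensorflow.keras.models import model_from_json

def load_model_from_json(modeljsonfname, weights_path=None):
    # recreate the architecture written by save_model_to_json
    with open(modeljsonfname) as json_file:
        model = model_from_json(json_file.read())
    if weights_path is not None:
        model.load_weights(weights_path)
    return model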
Example #3
import os

import cv2
from keras.preprocessing.image import ImageDataGenerator  # or tensorflow.keras.preprocessing.image
from segmentation_models import FPN

# test_dict and Convert are project-specific helpers defined elsewhere

def Simple(weight_path='./models/FPN/fpn_camouflage_baseline.h5',
           img_path=test_dict["CHAMELEON"],
           target_size=(256, 256),
           batch_size=1,
           save_path="./results/FPN/CHAMELEON"):
    os.makedirs(save_path, exist_ok=True)
    print(img_path)
    data_gen_args = dict(fill_mode='nearest')
    img_datagen = ImageDataGenerator(**data_gen_args)

    test_gen = img_datagen.flow_from_directory(img_path,
                                               batch_size=batch_size,
                                               target_size=target_size,
                                               shuffle=False,
                                               class_mode=None)

    model = FPN(backbone_name='resnet50', classes=1, activation='sigmoid')
    # print(model.summary())
    model.load_weights(weight_path)
    # opt = optimizers.Adam(lr=3e-4)
    # model.compile(optimizer=opt,
    #               loss=bce_dice_loss,
    #               metrics=["binary_crossentropy", mean_iou, dice_coef])
    predicted_list = model.predict_generator(test_gen,
                                             steps=None,
                                             max_queue_size=10,
                                             workers=1,
                                             use_multiprocessing=False,
                                             verbose=1)
    img_name_list = test_gen.filenames  # DirectoryIterator exposes file paths as .filenames

    for num, predict in enumerate(predicted_list, start=1):
        img_name = os.path.basename(img_name_list[num - 1])
        img_name = img_name.replace('.jpg', '.png')
        img = Convert(predict)
        cv2.imwrite(os.path.join(save_path, img_name), img)
        print("[INFO] {}/{}".format(num, len(img_name_list)))
Example #4
from keras import callbacks, optimizers
from keras.preprocessing.image import ImageDataGenerator
from segmentation_models import FPN

def simple(
        img_path='/media/dengpingfan/leone/dpfan/gepeng/Dataset/3Dataset/img',
        gt_path='/media/dengpingfan/leone/dpfan/gepeng/Dataset/3Dataset/gt',
        batchSize=7,
        target_size=(256, 256),
        epoch=24,
        lr=0.03,
        steps_per_epoch=663,
        model_save_path='./models/FPN/fpn_camouflage_baseline.h5'):
    seed = 1
    data_gen_args = dict(horizontal_flip=True, fill_mode='nearest')

    img_datagen = ImageDataGenerator(**data_gen_args)
    # note: rescale is added only after the image generator is built, so masks
    # are scaled to [0, 1] while input images keep their 0-255 range
    data_gen_args['rescale'] = 1. / 255
    mask_datagen = ImageDataGenerator(**data_gen_args)

    img_gen = img_datagen.flow_from_directory(img_path,
                                              batch_size=batchSize,
                                              target_size=target_size,
                                              shuffle=True,
                                              class_mode=None,
                                              seed=seed)

    mask_gen = mask_datagen.flow_from_directory(gt_path,
                                                color_mode='grayscale',
                                                batch_size=batchSize,
                                                target_size=target_size,
                                                shuffle=True,
                                                class_mode=None,
                                                seed=seed)
    train_gen = zip(img_gen, mask_gen)
    # model = Xnet(backbone_name='vgg16', encoder_weights='imagenet', decoder_block_type='transpose')
    model = FPN(backbone_name='resnet50', classes=1, activation='sigmoid')
    print(model.summary())
    opt = optimizers.SGD(lr=lr, momentum=0.9, decay=0.0001)
    model.compile(loss='binary_crossentropy',
                  optimizer=opt,
                  metrics=['binary_accuracy'])

    save_best = callbacks.ModelCheckpoint(filepath=model_save_path,
                                          monitor='loss',
                                          save_best_only=True,
                                          verbose=1)
    # fit_generator below receives no validation data, so 'val_loss' would
    # never be reported; monitor the training loss instead
    early_stopping = callbacks.EarlyStopping(monitor='loss',
                                             patience=30,
                                             verbose=2,
                                             mode='min')
    callbacks_list = [save_best, early_stopping]

    model.fit_generator(train_gen,
                        steps_per_epoch=steps_per_epoch,
                        epochs=epoch,
                        verbose=1,
                        callbacks=callbacks_list)
Example #5
from segmentation_models import Unet, FPN, Linknet, PSPNet

def define_model(architecture='Unet', BACKBONE='resnet34', input_shape=(None, None, 4), encoder_weights=None):
    print('In define_model function')
    if architecture == 'Unet':
        model = Unet(BACKBONE, classes=3, activation='softmax', encoder_weights=encoder_weights, input_shape=input_shape)
        print('Unet model defined')
    elif architecture == 'FPN':
        model = FPN(BACKBONE, classes=3, activation='softmax', encoder_weights=encoder_weights, input_shape=input_shape)
        print('FPN model defined')
    elif architecture == 'Linknet':
        model = Linknet(BACKBONE, classes=3, activation='softmax', encoder_weights=encoder_weights, input_shape=input_shape)
        print('Linknet model defined')
    elif architecture == 'PSPNet':
        model = PSPNet(BACKBONE, classes=3, activation='softmax', encoder_weights=encoder_weights, input_shape=input_shape)
        print('PSPNet model defined')
    else:
        raise ValueError('Unknown architecture: {}'.format(architecture))
    return model
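
Typical use, compiling the returned model (the loss and metric choices here are illustrative):

from segmentation_models.losses import cce_jaccard_loss
from segmentation_models.metrics import iou_score

model = define_model(architecture='FPN', BACKBONE='resnet34',
                     input_shape=(None, None, 4), encoder_weights=None)
model.compile('Adam', loss=cce_jaccard_loss, metrics=[iou_score])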
Example #6
from segmentation_models import FPN, Unet

# PSPNet50, Deeplabv3 and biard_net are project-specific builders imported
# elsewhere in the original source

def get_model(name, in_shape, n_classes, backend='resnet34'):
    if name == 'fpn':
        return FPN(backbone_name=backend,
                   input_shape=in_shape,
                   classes=n_classes,
                   encoder_weights=None)
    if name == 'unet':
        return Unet(backbone_name=backend,
                    input_shape=in_shape,
                    classes=n_classes,
                    encoder_weights=None)
    if name == 'pspnet':
        return PSPNet50(input_shape=in_shape, n_labels=n_classes)
    if name == 'deeplab':
        return Deeplabv3(input_shape=in_shape, classes=n_classes, weights=None)
    if name == 'biard':
        return biard_net(in_shape=in_shape, n_classes=n_classes)
    raise ValueError("Unknown model name")
Example #7
from segmentation_models import Unet, FPN, Linknet

# note: older segmentation_models releases accept freeze_encoder;
# from v1.0 the keyword is encoder_freeze

def build_pretrained_model(model_type,
                           backbone_name,
                           encoder_weights,
                           freeze_encoder,
                           activation='sigmoid'):
    if model_type == "Unet":
        return Unet(backbone_name=backbone_name,
                    encoder_weights=encoder_weights,
                    freeze_encoder=freeze_encoder,
                    activation=activation)
    elif model_type == "FPN":
        return FPN(backbone_name=backbone_name,
                   encoder_weights=encoder_weights,
                   freeze_encoder=freeze_encoder,
                   activation=activation)
    elif model_type == "Linknet":
        return Linknet(backbone_name=backbone_name,
                       encoder_weights=encoder_weights,
                       freeze_encoder=freeze_encoder,
                       activation=activation)
    else:
        raise ValueError('Pretrained model type is not supported: {}'.format(model_type))
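
A short usage sketch (parameter values are illustrative):

model = build_pretrained_model(model_type='FPN',
                               backbone_name='resnet50',
                               encoder_weights='imagenet',
                               freeze_encoder=True)
model.compile('Adam', loss='binary_crossentropy', metrics=['accuracy'])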
# the original mixes standalone keras and tf.keras (TF 1.x); adjust the
# imports below to the installed stack
import copy
import os
import random
import sys
import warnings

import cv2
import keras
import numpy as np
import pandas as pd
import skvideo.io
import tensorflow as tf
import wx
from keras import backend as K
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import load_model
from skimage.io import imread
from skimage.transform import resize


class unet():
    def __init__(self,
                 architecture,
                 backbone,
                 image_net,
                 single_labels,
                 colors,
                 address,
                 annotation_file,
                 image_folder,
                 annotation_folder,
                 bodyparts,
                 BATCH_SIZE=5,
                 train_flag=1,
                 annotation_assistance=0,
                 lr=0.001,
                 loss="Weighted Categorical_cross_entropy",
                 markersize=13):

        self.lr_drop = 25

        self.loss_function = loss

        self.markerSize = int(markersize)

        self.BATCH_SIZE = int(BATCH_SIZE)  # the higher the better
        self.annotation_file = os.path.join(address, annotation_file)
        self.bodyparts = bodyparts
        self.colors = colors
        self.error = 0
        self.single_labels = single_labels
        self.architecture = architecture
        self.backbone = backbone
        self.image_net = image_net
        if self.image_net == 'None':
            self.image_net = None
        self.annotation_assistance = annotation_assistance
        self.learning_rate = float(lr)
        self.train_flag = train_flag
        self.num_bodyparts = len(self.bodyparts)
        self.image_folder = image_folder
        self.annotation_folder = annotation_folder
        self.address = address

        try:
            file = open(self.address + os.sep + 'annotation_options.txt')
            self.options = file.readlines()
        except IOError:
            wx.MessageBox(
                'Error reading annotation_options.txt: please check that the'
                ' file exists or choose another annotation_options file',
                'Error!', wx.OK | wx.ICON_ERROR)
            return
        try:
            configuration_file = self.address + os.sep + 'file_configuration.txt'
            file = open(configuration_file)
            self.pref = file.readlines()
        except IOError:
            wx.MessageBox(
                'Error reading the configuration file: please check its'
                ' address and existence',
                'Error!', wx.OK | wx.ICON_ERROR)
            self.error = 1
            return

        self.scorer = self.options[1][:-1]

        self.number_videos = len(self.pref) - 2

        for i in range(0, self.number_videos):
            ind = self.pref[2 + i].rfind(os.sep)
            if self.pref[2 + i][ind + 1:-1] in self.annotation_file:
                self.name_video_list = self.pref[2 + i][ind + 1:-1]
                break
        try:
            self.dataFrame = pd.read_pickle(self.annotation_file)
        except Exception:
            wx.MessageBox(
                'Annotation reading error:\n '
                'annotation file not found', 'Error!', wx.OK | wx.ICON_ERROR)
            self.error = 1
            return

        self.annotated = np.where(
            ~np.isnan(self.dataFrame.iloc[:, 0].values)
            & (self.dataFrame.iloc[:, 0].values > 0))[0]
        if train_flag == 1:
            self.train()

        if self.error != 1:
            self.test()

    def weighted_categorical_crossentropy(self, weights):

        weights = K.variable(weights)

        def loss(y_true, y_pred):
            # scale predictions so that the class probabilities of each sample sum to 1
            y_pred /= K.sum(y_pred, axis=-1, keepdims=True)
            # clip to prevent NaN's and Inf's
            y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
            # calc
            loss = y_true * K.log(y_pred) * weights
            loss = -K.sum(loss, -1)
            return loss

        return loss

    def train(self):

        # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
        # config = tf.ConfigProto(gpu_options=gpu_options)
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, per_process_gpu_memory_fraction=0.6))
        session = tf.Session(config=session_config)

        addresss = self.address

        warnings.filterwarnings('ignore',
                                category=UserWarning,
                                module='skimage')
        seed = 42

        np.random.seed(10)

        print('Getting and resizing train images and masks ... ')
        sys.stdout.flush()
        counter = 0

        self.IMG_WIDTH = 288  # for faster computing on kaggle
        self.IMG_HEIGHT = 288  # for faster computing on kaggle

        counter = 0

        files_original_name = list()

        #self.num_bodyparts =1

        if len(self.annotated) == 0:
            wx.MessageBox(
                'Did you save your annotation?\n '
                'No annotation found in your file, please save and re-run',
                'Error!', wx.OK | wx.ICON_ERROR)
            self.error = 1
            return

        for i in range(0, len(self.annotated)):
            files_original_name.append(self.dataFrame[
                self.dataFrame.columns[0]]._stat_axis[self.annotated[i]][7:])

        img = imread(self.image_folder + os.sep + files_original_name[0])

        IMG_CHANNELS = len(np.shape(img))
        self.IMG_CHANNELS = IMG_CHANNELS

        self.file_name_for_prediction_confidence = files_original_name

        X_train = np.zeros((len(
            self.annotated), self.IMG_HEIGHT, self.IMG_WIDTH, IMG_CHANNELS),
                           dtype=np.uint8)
        Y_train = np.zeros((len(self.annotated), self.IMG_HEIGHT,
                            self.IMG_WIDTH, self.num_bodyparts + 1),
                           dtype=int)  # np.int is removed in modern NumPy
        New_train = np.zeros(
            (len(self.annotated), self.IMG_HEIGHT, self.IMG_WIDTH),
            dtype=int)

        for l in range(0, len(self.annotated)):
            img = imread(self.image_folder + os.sep + files_original_name[l])

            # mask_ = np.zeros((np.shape(img)[0],np.shape(img)[1],self.num_bodyparts))
            mask_ = np.zeros(
                (np.shape(img)[0], np.shape(img)[1], self.num_bodyparts))
            img = resize(img, (self.IMG_HEIGHT, self.IMG_WIDTH),
                         mode='constant',
                         preserve_range=True)

            X_train[counter] = img

            for j in range(0, self.num_bodyparts):
                mask_single_label = np.zeros((mask_.shape[0], mask_.shape[1]))

                #if annotation was assisted, x is negative

                points = np.asarray([
                    self.dataFrame[self.dataFrame.columns[j * 2]].values[
                        self.annotated[l]],
                    self.dataFrame[self.dataFrame.columns[j * 2 + 1]].values[
                        self.annotated[l]]
                ],
                                    dtype=float)
                points = np.abs(points)

                if np.isnan(points[0]):
                    continue

                cv2.circle(mask_single_label, (int(round(
                    (points[0] * (2**4)))), int(round(points[1] * (2**4)))),
                           int(round(self.markerSize * (2**4))),
                           (255, 255, 255),
                           thickness=-1,
                           shift=4)
                mask_[:, :, j] = mask_single_label

            mask_ = resize(mask_, (self.IMG_HEIGHT, self.IMG_WIDTH),
                           mode='constant',
                           preserve_range=True)
            a, mask_ = cv2.threshold(mask_, 200, 255, cv2.THRESH_BINARY)
            mask_ = mask_ / 255.0
            if len(np.shape(mask_)) == 2:
                mask_new = np.zeros(
                    (np.shape(mask_)[0], np.shape(mask_)[1], 1))
                mask_new[:, :, 0] = mask_
                mask_ = mask_new

            for j in range(0, self.num_bodyparts):
                New_train[counter] = New_train[counter] + mask_[:, :,
                                                                j] * (j + 1)

            # alternative method to build the ground truth
            # temp = temp + 1
            # temp[temp == 0] = 1
            # temp[temp > 1] = 0
            # Y_train[counter, :, :,1:] = mask_
            # Y_train[counter,:,:,0] = temp
            counter += 1
            #

        try:
            Y_train = tf.keras.utils.to_categorical(
                New_train, num_classes=self.num_bodyparts + 1)
        except Exception:
            wx.MessageBox(
                'two or more labels are overlapping!\n '
                'Check annotation or re-perform the labeling operation',
                'Error!', wx.OK | wx.ICON_ERROR)
            self.error = 1
            return

        counter = 0

        from segmentation_models import get_preprocessing

        self.processer = get_preprocessing(self.backbone)

        X_train = self.processer(X_train)

        print('Done!')

        from tensorflow.keras.preprocessing import image

        # Creating the training Image and Mask generator
        image_datagen = image.ImageDataGenerator(shear_range=0.5,
                                                 rotation_range=50,
                                                 zoom_range=0.2,
                                                 width_shift_range=0.2,
                                                 height_shift_range=0.2,
                                                 fill_mode='reflect')
        mask_datagen = image.ImageDataGenerator(shear_range=0.5,
                                                rotation_range=50,
                                                zoom_range=0.2,
                                                width_shift_range=0.2,
                                                height_shift_range=0.2,
                                                fill_mode='reflect')

        # Keep the same seed for image and mask generators so they fit together

        image_datagen.fit(X_train[:int(X_train.shape[0] * 0.9)],
                          augment=True,
                          seed=seed)
        mask_datagen.fit(Y_train[:int(Y_train.shape[0] * 0.9)],
                         augment=True,
                         seed=seed)

        x = image_datagen.flow(X_train[:int(X_train.shape[0] * 0.9)],
                               batch_size=self.BATCH_SIZE,
                               shuffle=True,
                               seed=seed)
        y = mask_datagen.flow(Y_train[:int(Y_train.shape[0] * 0.9)],
                              batch_size=self.BATCH_SIZE,
                              shuffle=True,
                              seed=seed)

        # Creating the validation Image and Mask generator
        image_datagen_val = image.ImageDataGenerator()
        mask_datagen_val = image.ImageDataGenerator()

        image_datagen_val.fit(X_train[int(X_train.shape[0] * 0.9):],
                              augment=True,
                              seed=seed)
        mask_datagen_val.fit(Y_train[int(Y_train.shape[0] * 0.9):],
                             augment=True,
                             seed=seed)

        x_val = image_datagen_val.flow(X_train[int(X_train.shape[0] * 0.9):],
                                       batch_size=self.BATCH_SIZE,
                                       shuffle=True,
                                       seed=seed)
        y_val = mask_datagen_val.flow(Y_train[int(Y_train.shape[0] * 0.9):],
                                      batch_size=self.BATCH_SIZE,
                                      shuffle=True,
                                      seed=seed)

        train_generator = zip(x, y)
        val_generator = zip(x_val, y_val)

        from segmentation_models import Unet, PSPNet, Linknet, FPN
        from segmentation_models.losses import CategoricalFocalLoss
        from segmentation_models.utils import set_trainable
        import segmentation_models
        from tensorflow.keras.optimizers import RMSprop, SGD

        if self.architecture == 'Linknet':

            self.model = Linknet(self.backbone,
                                 classes=self.num_bodyparts + 1,
                                 activation='softmax',
                                 encoder_weights=self.image_net,
                                 input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                              self.IMG_CHANNELS))

        elif self.architecture == 'unet':

            self.model = Unet(self.backbone,
                              classes=self.num_bodyparts + 1,
                              activation='softmax',
                              encoder_weights=self.image_net,
                              input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                           self.IMG_CHANNELS))
        elif self.architecture == 'PSPnet':
            self.model = PSPNet(self.backbone,
                                classes=self.num_bodyparts + 1,
                                activation='softmax',
                                encoder_weights=self.image_net,
                                input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                             self.IMG_CHANNELS))

        elif self.architecture == 'FPN':
            self.model = FPN(self.backbone,
                             classes=self.num_bodyparts + 1,
                             activation='softmax',
                             encoder_weights=self.image_net,
                             input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                          self.IMG_CHANNELS))

        weights = np.zeros((1, self.num_bodyparts + 1), dtype=float)
        weight = 1.0 / self.num_bodyparts

        num_zeros = 1
        # while (weight * 100 < 1):
        #     weight = weight * 100
        #     num_zeros += 1
        #
        # weight = int(weight * 100) / np.power(100, num_zeros)
        weights[0, 1:] = weight
        weights[0, 0] = 0.01 * len(self.bodyparts)

        while weights[0, 0] > weights[0, 1]:
            weights[0, 0] = weights[0, 0] / 10
            num_zeros += 1

        for i in range(1, len(self.bodyparts) + 1):
            weights[0, i] = weights[0, i] - 10**-(num_zeros + 1)

        if self.loss_function == "Weighted Categorical_cross_entropy":
            loss = self.weighted_categorical_crossentropy(weights)
        else:
            loss = segmentation_models.losses.DiceLoss(class_weights=weights)
        metric = segmentation_models.metrics.IOUScore(class_weights=weights,
                                                      per_image=True)
        self.model.compile(optimizer=RMSprop(lr=self.learning_rate),
                           loss=loss,
                           metrics=[metric])
        earlystopper = EarlyStopping(patience=6, verbose=1)
        #
        checkpointer = ModelCheckpoint(os.path.join(self.address, 'Unet.h5'),
                                       verbose=1,
                                       save_best_only=True)
        reduce_lr = keras.callbacks.LearningRateScheduler(self.lr_scheduler)

        #
        # model.fit_generator(train_generator, validation_data=val_generator, validation_steps=10, steps_per_epoch=50,
        #                                epochs=2, callbacks=[earlystopper, checkpointer],verbose=1)
        # model.load_weights(self.address + 'Temp_weights.h5')

        # set_trainable(model)
        #
        self.model.fit_generator(
            train_generator,
            validation_data=val_generator,
            steps_per_epoch=20,
            validation_steps=5,
            epochs=100,
            callbacks=[earlystopper, checkpointer, reduce_lr],
            verbose=1)

    def test(self):
        import segmentation_models
        from segmentation_models import Unet
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, per_process_gpu_memory_fraction=0.6))
        session = tf.Session(config=session_config)
        weights = np.zeros((1, self.num_bodyparts + 1), dtype=float)
        weight = 1.0 / self.num_bodyparts

        num_zeros = 1
        # while (weight * 100 < 1):
        #     weight = weight * 100
        #     num_zeros += 1
        #
        # weight = int(weight * 100) / np.power(100, num_zeros)
        weights[0, 1:] = weight
        #weights[0, 0] = 1 - np.sum(weights[0, 1:])
        weights[0, 0] = 0.01 * len(self.bodyparts)
        while weights[0, 0] > weights[0, 1]:
            weights[0, 0] = weights[0, 0] / 10
            num_zeros += 1

        # subtract a small epsilon from each body-part weight, as in train()
        for i in range(1, self.num_bodyparts + 1):
            weights[0, i] = weights[0, i] - 10**-(num_zeros + 1)
        weights = weights[0]
        if self.loss_function == "Weighted Categorical_cross_entropy":
            loss = self.weighted_categorical_crossentropy(weights)
        else:
            loss = segmentation_models.losses.DiceLoss(class_weights=weights)

        metric = segmentation_models.metrics.IOUScore(class_weights=weights,
                                                      per_image=True)

        model = load_model(os.path.join(self.address, 'Unet.h5'),
                           custom_objects={
                               'loss': loss,
                               'dice_loss':
                               segmentation_models.losses.DiceLoss,
                               'iou_score': metric
                           })

        OUTPUT = os.path.join(self.address,
                              self.name_video_list + '_prediction')

        os.makedirs(OUTPUT, exist_ok=True)
        files = os.listdir(self.image_folder)
        files_original_name = list()

        self.annotated = np.where(
            ~np.isnan(self.dataFrame.iloc[:, 0].values)
            & (self.dataFrame.iloc[:, 0].values > 0))[0]

        # if len(self.annotated)==0:
        #     wx.MessageBox('Did you save your annotation?\n '
        #                   'No annotation found in your file, please save and re-run'
        #                   , 'Error!', wx.OK | wx.ICON_ERROR)
        #     self.error = 1
        #     return

        for i in range(0, len(self.annotated)):
            files_original_name.append(files[self.annotated[i]])
        #files = np.unique(files_original_name)
        img = imread(self.image_folder + os.sep + files[0])

        IMG_CHANNELS = len(np.shape(img))
        self.IMG_CHANNELS = IMG_CHANNELS
        self.IMG_WIDTH = 288
        self.IMG_HEIGHT = 288
        counter = 0

        if self.annotation_assistance == 0:

            self.X_test = np.zeros(
                (len(files) - len(self.annotated), self.IMG_HEIGHT,
                 self.IMG_WIDTH, self.IMG_CHANNELS),
                dtype=np.uint8)

            self.testing_index = list()

            for l in range(0, len(files)):
                if l not in self.annotated:
                    self.testing_index.append(l)
                    img = imread(self.image_folder + os.sep +
                                 files[l])[:, :, :self.IMG_CHANNELS]
                    img = resize(img, (self.IMG_HEIGHT, self.IMG_WIDTH),
                                 mode='constant',
                                 preserve_range=True)
                    self.X_test[counter] = img
                    counter += 1

            self.testing_index = np.asarray(self.testing_index)

        else:
            index_file = os.path.join(
                os.path.dirname(self.annotation_file),
                self.name_video_list + '_index_annotation_auto.txt')
            if os.path.isfile(index_file):
                with open(index_file, 'r') as pref_ann:
                    temporary = pref_ann.readlines()
                for i in range(0, len(temporary)):
                    temporary[i] = temporary[i][:-1]
                temporary = np.asarray(temporary)
                self.frame_selected_for_annotation = np.sort(
                    temporary.astype(int))
            else:
                self.frame_selected_for_annotation = []

            if len(self.frame_selected_for_annotation) == 0:
                filename = self.ask()
                if filename == "":
                    return
                else:
                    num_frame_automatic = int(filename)
                    if num_frame_automatic > len(files):
                        wx.MessageBox(
                            'Error! Number of frames must be lower than the total number of frames\n '
                            'Input error', 'Error!', wx.OK | wx.ICON_ERROR)
                    my_list = list(range(0, len(files)))  # candidate frame indices
                    random.shuffle(my_list)
                    self.frame_selected_for_annotation = my_list[
                        0:num_frame_automatic]
                    output = open(
                        os.path.join(
                            os.path.dirname(self.annotation_file),
                            self.name_video_list +
                            '_index_annotation_auto.txt'), 'w')
                    for fra in self.frame_selected_for_annotation:
                        output.writelines(str(fra))
                        output.writelines('\n')
                    output.close()

            self.X_test = np.zeros(
                (len(self.frame_selected_for_annotation), self.IMG_HEIGHT,
                 self.IMG_WIDTH, self.IMG_CHANNELS),
                dtype=np.uint8)

            for l in range(0, len(files)):
                if l in self.frame_selected_for_annotation:
                    img = imread(self.image_folder + os.sep +
                                 files[l])[:, :, :self.IMG_CHANNELS]
                    img = resize(img, (self.IMG_HEIGHT, self.IMG_WIDTH),
                                 mode='constant',
                                 preserve_range=True)
                    self.X_test[counter] = img
                    counter += 1

        #X_test = X_test / 255.0
        seed = 42
        # Creating the training Image and Mask generator
        # image_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
        # # Keep the same seed for image and mask generators so they fit together
        # image_datagen.fit(self.X_test, augment=False,seed=seed)
        # x = image_datagen.flow(self.X_test, batch_size=1, shuffle=False, seed=seed)

        from segmentation_models import get_preprocessing

        self.processer = get_preprocessing(self.backbone)

        self.X_test = self.processer(self.X_test)

        x = self.X_test
        preds_test = model.predict(x, verbose=0)
        # Threshold predictions
        # preds_train_t = (preds_train > 0.5).astype(np.uint8)
        # K.clear_session()
        # preds_val_t = (preds_val > 0.5).astype(np.uint8)
        #preds_test_t = (preds_test > 0.5).astype(np.uint8)
        preds_test_t = preds_test
        # Create list of upsampled test masks
        preds_test_upsampled = []
        address_single_labels = os.path.join(
            self.address, self.name_video_list + '_Single_labels')
        self.annotated = np.where(
            ~np.isnan(self.dataFrame.iloc[:, 0].values)
            & (self.dataFrame.iloc[:, 0].values > 0))[0]
        os.makedirs(address_single_labels, exist_ok=True)

        if self.single_labels == 'Yes':

            if self.annotation_assistance == 1:
                self.testing_index = self.frame_selected_for_annotation
            # PREDICT AND SAVE SINGLE LABELS IMAGES
            for i in range(0, len(preds_test)):
                img = imread(self.image_folder + os.sep +
                             files[self.testing_index[i]])
                sizes_test = np.shape(img)[:-1]
                for j in range(0, self.num_bodyparts + 1):
                    preds_test_upsampled = resize(np.squeeze(
                        preds_test_t[i, :, :,
                                     j]), (sizes_test[0], sizes_test[1]),
                                                  mode='constant',
                                                  preserve_range=True)
                    preds_test_upsampled = (preds_test_upsampled *
                                            255).astype(int)
                    cv2.imwrite(
                        address_single_labels + os.sep + ("{:02d}".format(j)) +
                        files[self.testing_index[i]], preds_test_upsampled)

        #let us create a dataframe to store prediction confidence

        imlist = os.listdir(self.image_folder)
        self.index = np.sort(imlist)
        a = np.empty((
            len(self.index),
            3,
        ))
        self.dataFrame3 = None
        a[:] = np.nan
        self.scorer = 'user'
        for bodypart in self.bodyparts:
            index = pd.MultiIndex.from_product(
                [[self.scorer], [bodypart], ['x', 'y', 'confidence']],
                names=['scorer', 'bodyparts', 'coords'])
            frame = pd.DataFrame(a, columns=index, index=imlist)
            self.dataFrame3 = pd.concat([self.dataFrame3, frame], axis=1)
        num_columns = len(self.dataFrame3.columns)

        # video writer shared by both branches below; closed after the loop
        out = skvideo.io.FFmpegWriter(os.path.join(
            OUTPUT, self.name_video_list + '.avi'),
                                      outputdict={'-b': '300000000'})

        # in a regular testing run, every image except the annotated ones is analyzed
        if self.annotation_assistance == 0:

            for i in range(0, len(preds_test)):

                results = np.zeros((self.num_bodyparts * 2))
                results_plus_conf = np.zeros((self.num_bodyparts * 3))

                #here to update dataframe with annotations
                img = imread(self.image_folder + os.sep +
                             files[self.testing_index[i]])
                sizes_test = np.shape(img)[:-1]

                for j in range(0, self.num_bodyparts + 1):
                    if j == 0: continue

                    preds_test_upsampled = resize(np.squeeze(
                        preds_test_t[i, :, :,
                                     j]), (sizes_test[0], sizes_test[1]),
                                                  mode='constant',
                                                  preserve_range=True)
                    preds_test_upsampled = (preds_test_upsampled * 255).astype(
                        np.uint8)
                    results[(j - 1) * 2:(j - 1) * 2 +
                            2] = self.prediction_to_annotation(
                                preds_test_upsampled)
                    results_plus_conf[(j - 1) * 3:(j - 1) * 3 +
                                      3] = self.compute_confidence(
                                          preds_test_upsampled)
                    self.dataFrame[self.dataFrame.columns[(j - 1) * 2]].values[
                        self.testing_index[i]] = -results[(j - 1) * 2]
                    self.dataFrame[self.dataFrame.columns[
                        (j - 1) * 2 +
                        1]].values[self.testing_index[i]] = results[(j - 1) * 2
                                                                    + 1]
                    self.dataFrame3[self.dataFrame3.columns[
                        (j - 1) *
                        3]].values[self.testing_index[i]] = -results_plus_conf[
                            (j - 1) * 3]
                    self.dataFrame3[self.dataFrame3.columns[
                        (j - 1) * 3 +
                        1]].values[self.testing_index[i]] = results_plus_conf[
                            (j - 1) * 3 + 1]
                    self.dataFrame3[self.dataFrame3.columns[
                        (j - 1) * 3 +
                        2]].values[self.testing_index[i]] = results_plus_conf[
                            (j - 1) * 3 + 2]

                self.plot_annotation(img, results,
                                     files[self.testing_index[i]], OUTPUT, out)

        else:
            #if annotation assistance is requested, we only want to annotate the random frames extracted by the user
            for i in range(0, len(preds_test)):
                results = np.zeros((self.num_bodyparts * 2))
                results_plus_conf = np.zeros((self.num_bodyparts * 3))
                # here to update dataframe with annotations
                img = imread(self.image_folder + os.sep +
                             files[self.frame_selected_for_annotation[i]])
                sizes_test = np.shape(img)[:-1]
                for j in range(0, self.num_bodyparts + 1):
                    if j == 0: continue

                    preds_test_upsampled = resize(np.squeeze(
                        preds_test_t[i, :, :,
                                     j]), (sizes_test[0], sizes_test[1]),
                                                  mode='constant',
                                                  preserve_range=True)
                    preds_test_upsampled = (preds_test_upsampled * 255).astype(
                        np.uint8)
                    results[(j - 1) * 2:(j - 1) * 2 +
                            2] = self.prediction_to_annotation(
                                preds_test_upsampled)
                    results_plus_conf[(j - 1) * 3:(j - 1) * 3 +
                                      3] = self.compute_confidence(
                                          preds_test_upsampled)
                    self.plot_annotation(
                        img, results,
                        files[self.frame_selected_for_annotation[i]], OUTPUT,
                        out)
                    self.dataFrame[self.dataFrame.columns[(j - 1) * 2]].values[
                        self.frame_selected_for_annotation[i]] = -results[
                            (j - 1) * 2]
                    self.dataFrame[self.dataFrame.columns[
                        (j - 1) * 2 + 1]].values[
                            self.frame_selected_for_annotation[i]] = results[
                                (j - 1) * 2 + 1]
                    self.dataFrame3[self.dataFrame3.columns[
                        (j - 1) *
                        3]].values[self.frame_selected_for_annotation[
                            i]] = -results_plus_conf[(j - 1) * 3]
                    self.dataFrame3[self.dataFrame3.columns[
                        (j - 1) * 3 +
                        1]].values[self.frame_selected_for_annotation[
                            i]] = results_plus_conf[(j - 1) * 3 + 1]
                    self.dataFrame3[self.dataFrame3.columns[
                        (j - 1) * 3 +
                        2]].values[self.frame_selected_for_annotation[
                            i]] = results_plus_conf[(j - 1) * 3 + 2]

        self.dataFrame.to_pickle(self.annotation_file)
        self.dataFrame.to_csv(self.annotation_file + ".csv")
        out.close()
        try:
            self.dataFrame3.to_csv(self.annotation_file +
                                   "_with_confidence.csv")
        except Exception:
            wx.MessageBox(
                'Error in writing the results file:\n '
                'file not accessible', 'Error!', wx.OK | wx.ICON_ERROR)

    def lr_scheduler(self, epoch):
        return self.learning_rate * (0.5**(epoch // self.lr_drop))

    def prediction_to_annotation(self, annotation):
        # compute_corresponding_annotation_point
        # annotation = cv2.cvtColor(annotation, cv2.COLOR_BGR2GRAY)
        thresh, annotation = cv2.threshold(annotation, 150, 255,
                                           cv2.THRESH_BINARY)
        contour, hierarchy = cv2.findContours(annotation, cv2.RETR_EXTERNAL,
                                              cv2.CHAIN_APPROX_NONE)
        if len(contour) > 0:
            max_area = 0
            i_max = []
            for cc in contour:
                mom = cv2.moments(cc)
                area = mom['m00']
                if area > max_area:
                    max_area = area
                    i_max = cc

            if len(i_max) == 0:

                xc = 1
                yc = -1
            else:
                center = cv2.moments(i_max)

                xc = center['m10'] / center['m00']
                yc = center['m01'] / center['m00']

        else:
            #maybe one joint is missing, but the other were correctly identified
            # self.dataFrame[self.dataFrame.columns[(i - 1) * 2]].values[j] = -1
            # self.dataFrame[self.dataFrame.columns[(i - 1) * 2 + 1]].values[j] = -1
            xc = 1
            yc = -1

        return xc, yc

        # self.statusbar.SetStatusText("File saved")
        # MainFrame.updateZoomPan(self)
        #
        # copyfile(os.path.join(self.filename + ".csv"), os.path.join(self.filename + "_MANUAL.csv"))
        # self.dataFrame.to_csv(os.path.join(self.filename + ".csv"))
        # copyfile(os.path.join(self.filename + ".csv"), os.path.join(self.filename + "_MANUAL.csv"))
        # self.dataFrame.to_pickle(self.filename)  # where to save it, usually as a .pkl
        # wx.PyCommandEvent(wx.EVT_BUTTON.typeId, self.load.GetId())

    def compute_confidence(self, annotation):
        # compute_corresponding_annotation_point
        # annotation = cv2.cvtColor(annotation, cv2.COLOR_BGR2GRAY)
        confidence_image = copy.copy(annotation)
        raw_image = copy.copy(annotation)
        confidence_image[np.where(confidence_image != 0)] = 1

        mask = np.zeros_like(confidence_image)
        thresh, annotation = cv2.threshold(annotation, 150, 255,
                                           cv2.THRESH_BINARY)
        contour, hierarchy = cv2.findContours(annotation, cv2.RETR_EXTERNAL,
                                              cv2.CHAIN_APPROX_NONE)
        if len(contour) > 0:
            max_area = 0
            i_max = []
            for cc in contour:
                mom = cv2.moments(cc)
                area = mom['m00']
                if area > max_area:
                    max_area = area
                    i_max = cc

            if len(i_max) == 0:
                xc = 1
                yc = -1
                confidence = 0
            else:

                center = cv2.moments(i_max)
                cv2.drawContours(mask, [i_max], -1, (255, 255, 255), -1)

                mask[np.where(mask != 0)] = 1

                mean_values = np.multiply(confidence_image, mask)
                confidence = np.mean(raw_image[np.where(mean_values != 0)])
                confidence = confidence / 255.0

                xc = center['m10'] / center['m00']
                yc = center['m01'] / center['m00']

        else:
            # maybe one joint is missing, but the other were correctly identified
            # self.dataFrame[self.dataFrame.columns[(i - 1) * 2]].values[j] = -1
            # self.dataFrame[self.dataFrame.columns[(i - 1) * 2 + 1]].values[j] = -1
            xc = 1
            yc = -1
            confidence = 0

        return xc, yc, confidence

    def plot_annotation(self, image, points, name, OUTPUT, out):
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        for i in range(0, len(self.bodyparts)):
            if not (np.isnan(points[i * 2]) or np.isnan(points[i * 2 + 1])):
                cv2.circle(image, (int(round(
                    (points[i * 2] *
                     (2**4)))), int(round(points[i * 2 + 1] * (2**4)))),
                           int(round(self.markerSize * (2**4))),
                           self.colors[i] * 255,
                           thickness=-1,
                           shift=4)

        cv2.imwrite(os.path.join(OUTPUT, name), image)

        out.writeFrame(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))

    def ask(
        self,
        parent=None,
        message='Attention:\n you did not input the number of frames to'
                ' automatically detect in the previous interface.\n'
                ' Please fill the following textbox with that number, or cancel.'
    ):
        default_value = ""
        dlg = wx.TextEntryDialog(parent, message, value=default_value)
        dlg.ShowModal()
        result = dlg.GetValue()
        dlg.Destroy()
        return result
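
For orientation, a hedged instantiation sketch; instantiating the class immediately runs train() and then test(). Every path, label and color below is illustrative, with one color triple per bodypart as plot_annotation expects:

import numpy as np

tracker = unet(architecture='unet',
               backbone='resnet34',
               image_net='imagenet',
               single_labels='No',
               colors=np.array([(0.0, 0.0, 1.0), (0.0, 1.0, 0.0)]),  # hypothetical per-bodypart colors
               address='/data/project',            # hypothetical project folder
               annotation_file='annotations.pkl',  # hypothetical pickled DataFrame
               image_folder='/data/project/frames',
               annotation_folder='/data/project/labels',
               bodyparts=['head', 'tail'],
               BATCH_SIZE=5,
               train_flag=1)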
    def train(self):

        # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
        # config = tf.ConfigProto(gpu_options=gpu_options)
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, per_process_gpu_memory_fraction=0.6))
        session = tf.Session(config=session_config)

        addresss = self.address

        warnings.filterwarnings('ignore',
                                category=UserWarning,
                                module='skimage')
        seed = 42

        np.random.seed(10)

        print('Getting and resizing train images and masks ... ')
        sys.stdout.flush()
        counter = 0

        self.IMG_WIDTH = 288  # for faster computing on kaggle
        self.IMG_HEIGHT = 288  # for faster computing on kaggle

        counter = 0

        files_original_name = list()

        #self.num_bodyparts =1

        if len(self.annotated) == 0:
            wx.MessageBox(
                'Did you save your annotation?\n '
                'No annotation found in your file, please save and re-run',
                'Error!', wx.OK | wx.ICON_ERROR)
            self.error = 1
            return

        for i in range(0, len(self.annotated)):
            files_original_name.append(self.dataFrame[
                self.dataFrame.columns[0]]._stat_axis[self.annotated[i]][7:])

        img = imread(self.image_folder + os.sep + files_original_name[0])

        IMG_CHANNELS = len(np.shape(img))
        self.IMG_CHANNELS = IMG_CHANNELS

        self.file_name_for_prediction_confidence = files_original_name

        X_train = np.zeros((len(
            self.annotated), self.IMG_HEIGHT, self.IMG_WIDTH, IMG_CHANNELS),
                           dtype=np.uint8)
        Y_train = np.zeros((len(self.annotated), self.IMG_HEIGHT,
                            self.IMG_WIDTH, self.num_bodyparts + 1),
                           dtype=np.int)
        New_train = np.zeros(
            (len(self.annotated), self.IMG_HEIGHT, self.IMG_WIDTH),
            dtype=np.int)

        for l in range(0, len(self.annotated)):
            img = imread(self.image_folder + os.sep + files_original_name[l])

            # mask_ = np.zeros((np.shape(img)[0],np.shape(img)[1],self.num_bodyparts))
            mask_ = np.zeros(
                (np.shape(img)[0], np.shape(img)[1], self.num_bodyparts))
            img = resize(img, (self.IMG_HEIGHT, self.IMG_WIDTH),
                         mode='constant',
                         preserve_range=True)

            X_train[counter] = img

            for j in range(0, self.num_bodyparts):
                mask_single_label = np.zeros((mask_.shape[0], mask_.shape[1]))

                #if annotation was assisted, x is negative

                points = np.asarray([
                    self.dataFrame[self.dataFrame.columns[j * 2]].values[
                        self.annotated[l]],
                    self.dataFrame[self.dataFrame.columns[j * 2 + 1]].values[
                        self.annotated[l]]
                ],
                                    dtype=float)
                points = np.abs(points)

                if np.isnan(points[0]):
                    continue

                cv2.circle(mask_single_label, (int(round(
                    (points[0] * (2**4)))), int(round(points[1] * (2**4)))),
                           int(round(self.markerSize * (2**4))),
                           (255, 255, 255),
                           thickness=-1,
                           shift=4)
                mask_[:, :, j] = mask_single_label

            mask_ = resize(mask_, (self.IMG_HEIGHT, self.IMG_WIDTH),
                           mode='constant',
                           preserve_range=True)
            a, mask_ = cv2.threshold(mask_, 200, 255, cv2.THRESH_BINARY)
            mask_ = mask_ / 255.0
            if len(np.shape(mask_)) == 2:
                mask_new = np.zeros(
                    (np.shape(mask_)[0], np.shape(mask_)[1], 1))
                mask_new[:, :, 0] = mask_
                mask_ = mask_new

            for j in range(0, self.num_bodyparts):
                New_train[counter] = New_train[counter] + mask_[:, :,
                                                                j] * (j + 1)

            # alternative method to build the ground truth
            # temp = temp + 1
            # temp[temp == 0] = 1
            # temp[temp > 1] = 0
            # Y_train[counter, :, :,1:] = mask_
            # Y_train[counter,:,:,0] = temp
            counter += 1
            #

        try:
            Y_train = tf.keras.utils.to_categorical(
                New_train, num_classes=self.num_bodyparts + 1)
        except:
            wx.MessageBox(
                'two or more labels are overlapping!\n '
                'Check annotation or re-perform the labeling operation',
                'Error!', wx.OK | wx.ICON_ERROR)
            self.error = 1
            return

        counter = 0

        from segmentation_models import get_preprocessing

        self.processer = get_preprocessing(self.backbone)

        X_train = self.processer(X_train)

        print('Done!')

        from tensorflow.keras.preprocessing import image

        # Creating the training Image and Mask generator
        image_datagen = image.ImageDataGenerator(shear_range=0.5,
                                                 rotation_range=50,
                                                 zoom_range=0.2,
                                                 width_shift_range=0.2,
                                                 height_shift_range=0.2,
                                                 fill_mode='reflect')
        mask_datagen = image.ImageDataGenerator(shear_range=0.5,
                                                rotation_range=50,
                                                zoom_range=0.2,
                                                width_shift_range=0.2,
                                                height_shift_range=0.2,
                                                fill_mode='reflect')

        # Keep the same seed for image and mask generators so they fit together

        image_datagen.fit(X_train[:int(X_train.shape[0] * 0.9)],
                          augment=True,
                          seed=seed)
        mask_datagen.fit(Y_train[:int(Y_train.shape[0] * 0.9)],
                         augment=True,
                         seed=seed)

        x = image_datagen.flow(X_train[:int(X_train.shape[0] * 0.9)],
                               batch_size=self.BATCH_SIZE,
                               shuffle=True,
                               seed=seed)
        y = mask_datagen.flow(Y_train[:int(Y_train.shape[0] * 0.9)],
                              batch_size=self.BATCH_SIZE,
                              shuffle=True,
                              seed=seed)

        # Creating the validation Image and Mask generator
        image_datagen_val = image.ImageDataGenerator()
        mask_datagen_val = image.ImageDataGenerator()

        image_datagen_val.fit(X_train[int(X_train.shape[0] * 0.9):],
                              augment=True,
                              seed=seed)
        mask_datagen_val.fit(Y_train[int(Y_train.shape[0] * 0.9):],
                             augment=True,
                             seed=seed)

        x_val = image_datagen_val.flow(X_train[int(X_train.shape[0] * 0.9):],
                                       batch_size=self.BATCH_SIZE,
                                       shuffle=True,
                                       seed=seed)
        y_val = mask_datagen_val.flow(Y_train[int(Y_train.shape[0] * 0.9):],
                                      batch_size=self.BATCH_SIZE,
                                      shuffle=True,
                                      seed=seed)

        train_generator = zip(x, y)
        val_generator = zip(x_val, y_val)

        from segmentation_models import Unet, PSPNet, Linknet, FPN
        from segmentation_models.losses import CategoricalFocalLoss
        from segmentation_models.utils import set_trainable
        import segmentation_models
        from tensorflow.keras.optimizers import RMSprop, SGD

        if self.architecture == 'Linknet':

            self.model = Linknet(self.backbone,
                                 classes=self.num_bodyparts + 1,
                                 activation='softmax',
                                 encoder_weights=self.image_net,
                                 input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                              self.IMG_CHANNELS))

        elif self.architecture == 'unet':

            self.model = Unet(self.backbone,
                              classes=self.num_bodyparts + 1,
                              activation='softmax',
                              encoder_weights=self.image_net,
                              input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                           self.IMG_CHANNELS))
        elif self.architecture == 'PSPnet':
            self.model = PSPNet(self.backbone,
                                classes=self.num_bodyparts + 1,
                                activation='softmax',
                                encoder_weights=self.image_net,
                                input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                             self.IMG_CHANNELS))

        elif self.architecture == 'FPN':
            self.model = FPN(self.backbone,
                             classes=self.num_bodyparts + 1,
                             activation='softmax',
                             encoder_weights=self.image_net,
                             input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                          self.IMG_CHANNELS))

        weights = np.zeros((1, self.num_bodyparts + 1), dtype=float)
        weight = 1.0 / self.num_bodyparts

        num_zeros = 1
        # while (weight * 100 < 1):
        #     weight = weight * 100
        #     num_zeros += 1
        #
        # weight = int(weight * 100) / np.power(100, num_zeros)
        weights[0, 1:] = weight
        weights[0, 0] = 0.01 * len(self.bodyparts)

        while weights[0, 0] > weights[0, 1]:
            weights[0, 0] = weights[0, 0] / 10
            num_zeros += 1

        for i in range(1, len(self.bodyparts) + 1):
            weights[0, i] = weights[0, i] - 10**-(num_zeros + 1)

        if self.loss_function == "Weighted Categorical_cross_entropy":
            loss = self.weighted_categorical_crossentropy(weights)
        else:
            loss = segmentation_models.losses.DiceLoss(class_weights=weights)
        metric = segmentation_models.metrics.IOUScore(class_weights=weights,
                                                      per_image=True)
        self.model.compile(optimizer=RMSprop(lr=self.learning_rate),
                           loss=loss,
                           metrics=[metric])
        earlystopper = EarlyStopping(patience=6, verbose=1)
        #
        checkpointer = ModelCheckpoint(os.path.join(self.address, 'Unet.h5'),
                                       verbose=1,
                                       save_best_only=True)
        reduce_lr = keras.callbacks.LearningRateScheduler(self.lr_scheduler)

        #
        # model.fit_generator(train_generator, validation_data=val_generator, validation_steps=10, steps_per_epoch=50,
        #                                epochs=2, callbacks=[earlystopper, checkpointer],verbose=1)
        # model.load_weights(self.address + 'Temp_weights.h5')

        # set_trainable(model)
        #
        self.model.fit_generator(
            train_generator,
            validation_data=val_generator,
            steps_per_epoch=20,
            validation_steps=5,
            epochs=100,
            callbacks=[earlystopper, checkpointer, reduce_lr],
            verbose=1)
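
The weighted_categorical_crossentropy method referenced above is not shown in
this snippet. A minimal sketch of such a loss -- assuming softmax outputs and a
(1, n_classes) weight array like the one built above; the original
implementation may differ:

from tensorflow.keras import backend as K

def weighted_categorical_crossentropy(weights):
    # `weights` broadcasts against the last (class) axis,
    # e.g. the (1, n_classes) array assembled above.
    weights = K.constant(weights)

    def loss(y_true, y_pred):
        # Renormalize, then clip to avoid log(0).
        y_pred = y_pred / K.sum(y_pred, axis=-1, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1 - K.epsilon())
        return -K.sum(y_true * K.log(y_pred) * weights, axis=-1)

    return loss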
Example No. 10
train_generator = AmateurDataFrameDataGenerator(df_train, classes_id=class_ids, batch_size=4, dim_image=dim_image)
val_generator = AmateurDataFrameDataGenerator(df_val, classes_id=class_ids, batch_size=4, dim_image=dim_image)




# preprocess input
# from segmentation_models.backbones import get_preprocessing
# preprocess_input = get_preprocessing(BACKBONE)
# x_train = preprocess_input(x_train)
# x_val = preprocess_input(x_val)

# define model
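# Each assignment below overwrites the previous one; keep only the
# constructor for the architecture you intend to train.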

model = Unet(BACKBONE, classes=len(class_ids), encoder_weights='imagenet')
model = FPN(BACKBONE, classes=len(class_ids), encoder_weights='imagenet')
model = PSPNet(BACKBONE, classes=len(class_ids), encoder_weights='imagenet')
model = Linknet(BACKBONE, classes=len(class_ids), encoder_weights='imagenet')

model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])
model.summary()

modelCheckpoint = keras.callbacks.ModelCheckpoint(filepath='segmod_weights.{epoch:02d}-{val_loss:.4f}.hdf5',
                                                  monitor='val_loss',
                                                  verbose=0, save_best_only=False, save_weights_only=False,
                                                  mode='auto', period=1)
reduceLROnPlateau = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=7, verbose=1,
                                                      mode='auto', min_delta=0.001, cooldown=0, min_lr=10e-7)


model.fit_generator(generator=train_generator, steps_per_epoch=None, epochs=10, verbose=1,
Example No. 11
    input_layer = (config.img_w, config.img_h, config.im_bands)
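    # Unet is built unconditionally as the default; the branches below
    # replace it when config.network names another architecture.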
    model = Unet(backbone_name=config.BACKBONE,
                 input_shape=input_layer,
                 classes=config.nb_classes,
                 activation=config.activation,
                 encoder_weights=config.encoder_weights)
    if 'pspnet' in config.network:
        model = PSPNet(backbone_name=config.BACKBONE,
                       input_shape=input_layer,
                       classes=config.nb_classes,
                       activation=config.activation,
                       encoder_weights=config.encoder_weights)
    elif 'fpn' in config.network:
        model = FPN(backbone_name=config.BACKBONE,
                    input_shape=input_layer,
                    classes=config.nb_classes,
                    activation=config.activation,
                    encoder_weights=config.encoder_weights)
    elif 'linknet' in config.network:
        model = Linknet(backbone_name=config.BACKBONE,
                        input_shape=input_layer,
                        classes=config.nb_classes,
                        activation=config.activation,
                        encoder_weights=config.encoder_weights)
    else:
        pass

    model.summary()
    print("Train by : {}_{}".format(config.network, config.BACKBONE))

    # sys.exit(-1)
Example No. 12
def train_generatorh5(params):
    from hitif_losses import dice_coef_loss_bce
    from hitif_losses import double_head_loss

    print('-' * 30)
    print('Loading and preprocessing train data...')
    print('-' * 30)

    # Prepare for splitting the training set
    imgs_ind = np.arange(number_of_imgs)
    np.random.shuffle(imgs_ind)

    # Split 80-20
    train_last_id = int(number_of_imgs * 0.80)

    # Generators
    training_generator = DataGeneratorH5(
        source_target_list_IDs=imgs_ind[0:train_last_id].copy(), **params)
    validation_generator = DataGeneratorH5(
        source_target_list_IDs=imgs_ind[train_last_id:number_of_imgs].copy(),
        **params)

    print('-' * 30)
    print('Creating and compiling model...')
    print('-' * 30)

    layers = get_feature_layers(backbone_name, 4)
    if backbone_type == 'FPN':
        model = FPN(input_shape=(None, None, num_channels),
                    classes=num_mask_channels,
                    encoder_weights=encoder_weights,
                    encoder_freeze=freezeFlag,
                    backbone_name=backbone_name,
                    activation=act_fcn,
                    encoder_features=layers)
    elif backbone_type == 'Unet':
        model = Unet(input_shape=(None, None, num_channels),
                     classes=num_mask_channels,
                     encoder_weights=encoder_weights,
                     encoder_freeze=freezeFlag,
                     backbone_name=backbone_name,
                     activation=act_fcn,
                     encoder_features=layers)
    #model.summary()
    #model.compile(optimizer=Adam(lr=1e-5), loss='binary_crossentropy', metrics=['binary_crossentropy','mean_squared_error',dice_coef, dice_coef_batch, dice_coef_loss_bce,focal_loss()])
    #model.compile(optimizer=Adam(lr=1e-3), loss=dice_coef_loss_bce, metrics=['binary_crossentropy','mean_squared_error',dice_coef, dice_coef_batch,focal_loss()])
    #model.compile(optimizer=Adam(lr=1e-3), loss=loss_fcn, metrics=['binary_crossentropy','mean_squared_error',dice_coef, dice_coef_batch,focal_loss()])
    if loss_fcn == 'dice_coef_loss_bce':
        model.compile(optimizer=Adam(lr=1e-3), loss=dice_coef_loss_bce)
    elif loss_fcn == 'double_head_loss':
        model.compile(optimizer=Adam(lr=1e-3), loss=double_head_loss)
    else:
        model.compile(optimizer=Adam(lr=1e-3), loss=loss_fcn)

    # Loading previous weights for restarting
    if oldmodelwtsfname is not None:
        if os.path.isfile(oldmodelwtsfname) and reloadFlag:
            print('-' * 30)
            print('Loading previous weights ...')

            weights = np.load(oldmodelwtsfname, allow_pickle=True)
            model.set_weights(weights)
            #model.load_weights(oldmodelwtsfname)

    checkpoint_path = get_checkpoint_path(log_dir_name)
    print("checkpoint_path:", checkpoint_path)
    model_checkpoint = ModelCheckpoint(checkpoint_path,
                                       monitor='val_loss',
                                       save_best_only=True,
                                       save_weights_only=True)
    custom_checkpoint = Checkpoints(checkpoint_path,
                                    monitor='val_loss',
                                    verbose=1)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                                  factor=0.1,
                                  patience=25,
                                  min_lr=1e-6,
                                  verbose=1)
    # restore_best_weights is an EarlyStopping option, not a
    # ReduceLROnPlateau one, so it is set on the early stopper instead.
    model_es = EarlyStopping(monitor='val_loss',
                             min_delta=1e-7,
                             patience=15,
                             verbose=1,
                             mode='auto',
                             restore_best_weights=True)
    csv_logger = CSVLogger(csvfname, append=True)

    #my_callbacks = [reduce_lr, model_es, csv_logger]
    #my_callbacks = [model_checkpoint, reduce_lr, model_es, csv_logger]
    my_callbacks = [custom_checkpoint, reduce_lr, model_es, csv_logger]
    print('-' * 30)
    print('Fitting model encoder freeze...')
    print('-' * 30)
    if freezeFlag and num_coldstart_epochs > 0:
        model.fit_generator(generator=training_generator,
                            validation_data=validation_generator,
                            use_multiprocessing=True,
                            workers=num_gen_workers,
                            epochs=num_coldstart_epochs,
                            callbacks=my_callbacks,
                            verbose=2)

    # release all layers for training
    set_trainable(model)  # set all layers trainable and recompile model
    #model.summary()

    print('-' * 30)
    print('Fitting full model...')
    print('-' * 30)

    ## Retrain after the cold-start
    model.fit_generator(generator=training_generator,
                        validation_data=validation_generator,
                        use_multiprocessing=True,
                        workers=num_gen_workers,
                        epochs=num_finetuning_epochs,
                        callbacks=my_callbacks,
                        verbose=2)

    ## <<FIXME>>: GZ will work on it.
    # Find the last best epoch model weights and then symlink it to the modelwtsfname
    # Note that the symlink will have issues on NON-Linux OS so it is better to copy.
if config.model == "Unet":
    model = Unet(backbone_name=config.backbone,
                 encoder_weights=config.weights,
                 n_upsample_blocks=4,
                 decoder_block_type=config.decoder_block_type,
                 classes=config.nb_class,
                 activation=config.activation)
elif config.model == "PSPNet":
    model = PSPNet(
        backbone_name=config.backbone,
        encoder_weights=config.weights,
        #decoder_block_type=config.decoder_block_type,
        classes=config.nb_class,
        activation=config.activation)
elif config.model == "FPN":
    model = FPN(
        backbone_name=config.backbone,
        encoder_weights=config.weights,
        #decoder_block_type=config.decoder_block_type,
        classes=config.nb_class,
        activation=config.activation)
else:
    raise ValueError("Unsupported model: {}".format(config.model))
model.compile(
    optimizer="Adam",  #optimizer=Adam(lr=1e-4, decay=5e-4)
    loss=bce_dice_loss,
    metrics=["binary_crossentropy", mean_iou, dice_coef])

# plot_model(model, to_file=os.path.join(model_path, config.exp_name+".png"))
if os.path.exists(os.path.join(model_path, config.exp_name + ".txt")):
    os.remove(os.path.join(model_path, config.exp_name + ".txt"))
with open(os.path.join(model_path, config.exp_name + ".txt"), 'w') as fh:
    model.summary(positions=[.3, .55, .67, 1.],
                  print_fn=lambda x: fh.write(x + '\n'))
                out = Conv2D(3,
                             kernel_size=3,
                             activation=active3,
                             padding='same',
                             kernel_initializer='he_normal')(l5)

                #    out= Conv2D(1, kernel_size=k_3, activation='relu', padding='same', kernel_initializer='he_normal',name="nothing")(out1)
                model = Model(inp, out, name='shallow')

            elif k_mod == "FPN":
                # N = x_train.shape[-1]
                if N == 3:
                    model = FPN(BACKBONE,
                                input_shape=(size, size, 3),
                                classes=3,
                                activation='softmax',
                                encoder_weights='imagenet',
                                encoder_freeze=False)
                else:
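                    # N-channel input trick: keep the ImageNet-pretrained FPN
                    # (which expects 3 channels) intact and prepend a 1x1 conv
                    # that projects the N input channels down to 3.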
                    base_model = FPN(BACKBONE,
                                     input_shape=(size, size, 3),
                                     classes=3,
                                     activation='softmax',
                                     encoder_weights='imagenet',
                                     encoder_freeze=False)
                    inp = Input(shape=(size, size, N))
                    bn = BatchNormalization()(inp)
                    l1 = Conv2D(3, (1, 1))(
                        bn)  # map N channels data to 3 channels
                    out = base_model(l1)
                    model = Model(inp, out, name=base_model.name)
Example No. 15
from segmentation_models import FPN, get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score
from tensorflow.keras.datasets import mnist


BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)




# load your data (note: MNIST is a stand-in here -- its (N, 28, 28) grayscale
# digit arrays match neither the (224, 224, 6) input shape nor the 9 mask
# channels the model below declares; substitute a real segmentation dataset)
(x_train, y_train), (x_val, y_val) = mnist.load_data()

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = FPN(BACKBONE, input_shape=(224, 224, 6), classes=9, encoder_weights=None)
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
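
As written, the snippet above will not run with MNIST data for the shape
reasons noted in its comments. A minimal shape-consistent sketch using random
stand-in data (hypothetical shapes chosen only to match the declared model):

import numpy as np
from segmentation_models import FPN
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

# Random stand-in data: eight 224x224 six-channel images
# with 9-channel binary masks.
x_train = np.random.rand(8, 224, 224, 6).astype('float32')
y_train = (np.random.rand(8, 224, 224, 9) > 0.5).astype('float32')

model = FPN('resnet34', input_shape=(224, 224, 6), classes=9,
            encoder_weights=None)
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])
model.fit(x=x_train, y=y_train, batch_size=2, epochs=1)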
Example No. 16
if __name__ == '__main__':
    generator = SegDataGenerator(TRAIN_PATH,
                                 ANNO_PATH,
                                 batch_size=2,
                                 input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
                                 mask_shape=(IMG_HEIGHT, IMG_WIDTH,
                                             len(CLASS_COLOR)),
                                 preprocessing_function=make_aug,
                                 classes_colors=CLASS_COLOR,
                                 prob_aug=1)
    input_layer = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
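    # Note: the input_tensor, use_batchnorm and dropout keywords appear to
    # come from pre-1.0 segmentation_models; in 1.x the FPN equivalents are
    # pyramid_use_batchnorm and pyramid_dropout, and input_tensor was dropped.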
    model = FPN(backbone_name=args.backbone,
                input_tensor=input_layer,
                encoder_weights='imagenet',
                classes=len(CLASS_COLOR),
                use_batchnorm=True,
                dropout=0.25,
                activation='softmax')

    save_name = 'weights/' + args.backbone + '.h5'
    callbacks_list = [
        ModelCheckpoint(save_name,
                        monitor='loss',
                        verbose=1,
                        save_best_only=True,
                        mode='min',
                        save_weights_only=True),
        ReduceLROnPlateau(monitor='loss', factor=0.2, patience=2, min_lr=1e-5)
    ]
Example No. 17
    def train(self):

        seed = 42
        # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
        # config = tf.ConfigProto(gpu_options=gpu_options)
        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=True, per_process_gpu_memory_fraction=0.6))
        session = tf.Session(config=session_config)

        from segmentation_models import get_preprocessing

        self.processer = get_preprocessing(self.backbone)

        X_train = self.processer(self.X_train)

        Y_train = self.Y_train

        print('Done!')

        from tensorflow.keras.preprocessing import image

        # Creating the training Image and Mask generator
        image_datagen = image.ImageDataGenerator(shear_range=0.5,
                                                 rotation_range=50,
                                                 zoom_range=0.2,
                                                 width_shift_range=0.2,
                                                 height_shift_range=0.2,
                                                 fill_mode='reflect')
        mask_datagen = image.ImageDataGenerator(shear_range=0.5,
                                                rotation_range=50,
                                                zoom_range=0.2,
                                                width_shift_range=0.2,
                                                height_shift_range=0.2,
                                                fill_mode='reflect')

        # Keep the same seed for image and mask generators so they fit together

        image_datagen.fit(X_train[:int(X_train.shape[0] * 0.9)],
                          augment=True,
                          seed=seed)
        mask_datagen.fit(Y_train[:int(Y_train.shape[0] * 0.9)],
                         augment=True,
                         seed=seed)

        x = image_datagen.flow(X_train[:int(X_train.shape[0] * 0.9)],
                               batch_size=self.BATCH_SIZE,
                               shuffle=True,
                               seed=seed)
        y = mask_datagen.flow(Y_train[:int(Y_train.shape[0] * 0.9)],
                              batch_size=self.BATCH_SIZE,
                              shuffle=True,
                              seed=seed)

        # Creating the validation Image and Mask generator
        image_datagen_val = image.ImageDataGenerator()
        mask_datagen_val = image.ImageDataGenerator()

        image_datagen_val.fit(X_train[int(X_train.shape[0] * 0.9):],
                              augment=True,
                              seed=seed)
        mask_datagen_val.fit(Y_train[int(Y_train.shape[0] * 0.9):],
                             augment=True,
                             seed=seed)

        x_val = image_datagen_val.flow(X_train[int(X_train.shape[0] * 0.9):],
                                       batch_size=self.BATCH_SIZE,
                                       shuffle=True,
                                       seed=seed)
        y_val = mask_datagen_val.flow(Y_train[int(Y_train.shape[0] * 0.9):],
                                      batch_size=self.BATCH_SIZE,
                                      shuffle=True,
                                      seed=seed)

        train_generator = zip(x, y)
        val_generator = zip(x_val, y_val)

        from segmentation_models import Unet, PSPNet, Linknet, FPN
        from segmentation_models.losses import CategoricalFocalLoss
        from segmentation_models.utils import set_trainable
        import segmentation_models
        from tensorflow.keras.optimizers import RMSprop, SGD
        #model = self.model(self.IMG_HEIGHT,self.IMG_WIDTH,self.IMG_CHANNELS)

        if self.architecture == 'Linknet':

            self.model = Linknet(self.backbone,
                                 classes=self.num_bodyparts + 1,
                                 activation='softmax',
                                 encoder_weights=self.image_net,
                                 input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                              self.IMG_CHANNELS))

        elif self.architecture == 'unet':

            self.model = Unet(self.backbone,
                              classes=self.num_bodyparts + 1,
                              activation='softmax',
                              encoder_weights=self.image_net,
                              input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                           self.IMG_CHANNELS))
        elif self.architecture == 'PSPnet':
            self.model = PSPNet(self.backbone,
                                classes=self.num_bodyparts + 1,
                                activation='softmax',
                                encoder_weights=self.image_net,
                                input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                             self.IMG_CHANNELS))

        elif self.architecture == 'FPN':
            self.model = FPN(self.backbone,
                             classes=self.num_bodyparts + 1,
                             activation='softmax',
                             encoder_weights=self.image_net,
                             input_shape=(self.IMG_WIDTH, self.IMG_HEIGHT,
                                          self.IMG_CHANNELS))
        else:
            raise ValueError('Unsupported architecture: ' +
                             self.architecture)

        weights = np.zeros((1, self.num_bodyparts + 1), dtype=float)
        weight = 1.0 / self.num_bodyparts

        num_zeros = 1
        # while (weight * 100 < 1):
        #     weight = weight * 100
        #     num_zeros += 1
        #
        # weight = int(weight * 100) / np.power(100, num_zeros)
        weights[0, 1:] = weight
        weights[0, 0] = 0.01 * len(self.bodyparts)

        while weights[0, 0] > weights[0, 1]:
            weights[0, 0] = weights[0, 0] / 10
            num_zeros += 1

        for i in range(1, len(self.bodyparts) + 1):
            weights[0, i] = weights[0, i] - 10**-(num_zeros + 1)

        if self.loss_function == "Weighted Categorical_cross_entropy":
            loss = self.weighted_categorical_crossentropy(weights)
        else:
            loss = segmentation_models.losses.DiceLoss(class_weights=weights)
        metric = segmentation_models.metrics.IOUScore(class_weights=weights,
                                                      per_image=True)
        self.model.compile(optimizer=RMSprop(lr=self.learning_rate),
                           loss=loss,
                           metrics=[metric])
        earlystopper = EarlyStopping(patience=6, verbose=1)
        #
        checkpointer = ModelCheckpoint(os.path.join(self.address, 'Unet.h5'),
                                       verbose=1,
                                       save_best_only=True)
        reduce_lr = keras.callbacks.LearningRateScheduler(self.lr_scheduler)

        #
        # model.fit_generator(train_generator, validation_data=val_generator, validation_steps=10, steps_per_epoch=50,
        #                                epochs=2, callbacks=[earlystopper, checkpointer],verbose=1)
        # model.load_weights(self.address + 'Temp_weights.h5')

        # set_trainable(model)
        #
        self.model.fit_generator(
            train_generator,
            validation_data=val_generator,
            steps_per_epoch=20,
            validation_steps=5,
            epochs=100,
            callbacks=[earlystopper, checkpointer, reduce_lr],
            verbose=1)
    assert len(train_generator.cat_ids) == len(val_generator.cat_ids)
else:
    assert False

# preprocess input
# from segmentation_models.backbones import get_preprocessing
# preprocess_input = get_preprocessing(BACKBONE)
# x_train = preprocess_input(x_train)
# x_val = preprocess_input(x_val)

# define model

#model = Unet(BACKBONE, classes=number_of_classes, encoder_weights='imagenet',activation='softmax')
# model = Linknet(BACKBONE, classes=number_of_classes, encoder_weights='imagenet')

model = FPN(BACKBONE, classes=number_of_classes, encoder_weights='imagenet', activation='softmax')
# model = PSPNet(BACKBONE, classes=number_of_classes, encoder_weights='imagenet')



# for multiclass segmentation choose another loss and metric
model.compile('Adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])

#model.compile('Adam', loss=cce_jaccard_loss, metrics=[jaccard_score])
model.summary()


#from keras.utils import plot_model
#plot_model(model, to_file='model.png')

modelCheckpoint = keras.callbacks.ModelCheckpoint(filepath=model_checkpoint_prefix+'_weights.{epoch:02d}-{val_loss:.4f}.hdf5',
Example No. 19
reduceLROnPlateau = keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, verbose=1,
                                                      mode='auto', min_delta=0.01, cooldown=0, min_lr=10e-7)

'''
Sometimes it is useful to train only the randomly initialized decoder, so that
huge gradients during the first steps of training do not damage the properly
trained encoder weights. In this case, just pass encoder_freeze=True while
initializing the model.
'''
freeze_encoder = True
if os.name == 'nt':
    freeze_encoder = False

if number_of_classes > 1:
    if architecture == 'PSP':
        model = PSPNet(backbone, input_shape=dim_image, classes=number_of_classes, encoder_weights='imagenet', activation='softmax', encoder_freeze=freeze_encoder)
    elif architecture == 'FPN':
        model = FPN(backbone, input_shape=dim_image, classes=number_of_classes, encoder_weights='imagenet', activation='softmax', encoder_freeze=freeze_encoder)
    else:
        assert False
    model.compile('Adam', loss='categorical_crossentropy', metrics=['categorical_accuracy'])
    model.summary()
else:
    assert 1 == number_of_classes
    if architecture == 'PSP':
        model = PSPNet(backbone, input_shape=dim_image, classes=1, encoder_weights='imagenet', activation='sigmoid', encoder_freeze=freeze_encoder)
    elif architecture == 'FPN':
        model = FPN(backbone, input_shape=dim_image, classes=1, encoder_weights='imagenet', activation='sigmoid', encoder_freeze=freeze_encoder)
    elif architecture == 'Linknet':
        model = Linknet(backbone, input_shape=dim_image, classes=1, encoder_weights='imagenet', activation='sigmoid', encoder_freeze=freeze_encoder)
    elif architecture == 'Unet':
        model = Unet(backbone, input_shape=dim_image, classes=1, encoder_weights='imagenet', activation='sigmoid', encoder_freeze=freeze_encoder)
    else:
        assert False
model = None
if model_type == "unet":
    model = Unet(
        model_backbone,
        classes=2,
        input_shape=(input_size[0], input_size[1], 3),
        encoder_weights=model_weights,
        decoder_filters=unet_filters,
        activation="softmax",
    )
elif model_type == "fpn":
    model = FPN(
        model_backbone,
        classes=2,
        input_shape=(input_size[0], input_size[1], 3),
        encoder_weights=model_weights,
        pyramid_block_filters=fpn_filters,
        pyramid_dropout=fpn_dropout,
        activation="softmax",
    )
else:
    raise ValueError("Unknown model_type: " + str(model_type))

model.summary()

model.compile(
    optimizer=optimizer,
    loss=globals()[loss],
    metrics=["accuracy", dice_coef, segmentation_models.metrics.iou_score],
)

# Choose checkpoint path
checkpoint_path = "./Weights/"
import tensorflow as tf
import numpy as np
from segmentation_models import FPN
import matplotlib.pyplot as plt

model = FPN(
    backbone_name="mobilenetv2",
    input_shape=(None, None, 3),
    classes=7,
    activation="sigmoid",
    weights=None,
    encoder_weights="imagenet",
    encoder_features="default",
    pyramid_block_filters=256,
    pyramid_use_batchnorm=True,
    pyramid_aggregation="concat",
    pyramid_dropout=None,
)
image = np.load("./data/image.npy")
label = np.load("./data/label.npy")

model.compile(
    loss=lambda labels, predictions: tf.keras.losses.binary_crossentropy(
        labels[:, :, :, 1:], predictions
    ),
    optimizer="adam",
    metrics=["accuracy"],
)
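# The lambda loss drops label channel 0 (presumably background) so that the
# remaining 7 channels align with the model's 7 sigmoid outputs.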

model.fit(
    [image],
Example No. 22
####################################################
############# Set model parameters #################
####################################################

if model_name == 'Unet':
    model = Unet(backbone_name=backbone_name,
                 classes=n_classes,
                 activation='softmax')
elif model_name == 'PSPNet':
    model = PSPNet(backbone_name=backbone_name,
                   classes=n_classes,
                   activation='softmax')
elif model_name == 'FPN':
    model = FPN(backbone_name=backbone_name,
                classes=n_classes,
                activation='softmax')
elif model_name == 'Linknet':
    model = Linknet(backbone_name=backbone_name,
                    classes=n_classes,
                    activation='softmax')
else:
    raise ValueError('Please provide the right model name')

model.compile('Adam',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

####################################################
############# Training model #######################
####################################################