Example #1
0
	def __init__(self):
		"""Read model names, weight-loading options and batch size from the CLI config."""
		cli = config.parse_arguments()
		# model identifiers
		self.roi_model, self.main_model = cli.roi_model, cli.main_model
		# weight-loading options, stored exactly as given by the config
		self.load_weights_roi, self.load_weights_main = cli.load_weights_roi, cli.load_weights_main
		# note: the attribute is called store_model but comes from the store_txt argument
		self.store_model = cli.store_txt
		self.batch_size = cli.batch_size
Example #2
0
 def __init__(self):
     """Read input size, class count, loss names and the store path from the CLI config."""
     opts = config.parse_arguments()
     # input geometry
     self.width, self.height = opts.width, opts.height
     self.classes = opts.classes
     # loss selections for the two models
     self.loss_roi = opts.loss_roi
     self.loss_main = opts.loss_main
     # store_model_path is taken from the store_txt argument
     self.store_model_path = opts.store_txt
Example #3
0
 def __init__(self, init1, init2):
     """Collect the ROI / pre-training settings.

     init1, init2 -- stored unchanged as ``init_f`` and ``init_b``.
     All other attributes are read from the CLI config or are fixed defaults.
     """
     opts = config.parse_arguments()
     # values supplied by the caller
     self.init_f, self.init_b = init1, init2
     # values read from the CLI config
     self.roi_activation = opts.roi_activation
     self.roi_optimizer = opts.pre_optimizer
     self.roi_shape = opts.roi_shape_roi
     self.loss_pre = opts.loss_pre
     self.pretrain_window = opts.pretrain_window
     self.epoch_pre = opts.epochs_pre
     self.monitor = opts.monitor_callbacks
     self.store_model = opts.store_txt
     # fixed defaults
     self.get_history = False
     self.verbose = 1
     self.units_ed = 64
     self.weight_name = 'first'
Example #4
0
    def __init__(self, X, Y):
        """Hold the images/masks and the data-augmentation settings.

        X -- images, stored as ``self.images``; ``self.index`` holds
             ``np.arange(len(X))``.
        Y -- masks, stored as ``self.mask``.

        The remaining settings come from the CLI config. When
        ``data_augm_classic`` is the string 'True', a Keras
        ``ImageDataGenerator`` built from those settings is stored as
        ``self.datagen``; otherwise ``self.datagen`` is not set.
        """
        self.images = X
        self.mask = Y
        args = config.parse_arguments()

        # batching / iteration
        self.max = args.max_loops
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle

        # options forwarded to ImageDataGenerator below
        self.featurewise_center = args.featurewise_center
        self.samplewise_center = args.samplewise_center
        self.featurewise_std_normalization = args.featurewise_std_normalization
        self.samplewise_std_normalization = args.samplewise_std_normalization
        self.zca_whitening = args.zca_whitening
        self.rotation_range = args.rotation_range
        self.width_shift_range = args.width_shift_range
        self.height_shift_range = args.height_shift_range
        self.horizontal_flip = args.horizontal_flip
        self.vertical_flip = args.vertical_flip

        # other settings read from the config
        self.data_augm = args.data_augm_classic
        self.alpha = args.alpha
        self.sigma = args.sigma
        self.noise = args.noise
        self.random_apply_in_batch = args.random_apply_in_batch
        # attribute spelling kept as-is; other code may refer to it
        self.normilize = args.normalize

        self.index = np.arange(len(self.images))

        if self.data_augm == 'True':
            self.datagen = ImageDataGenerator(
                # set input mean to 0 over the dataset
                featurewise_center=self.featurewise_center,
                # set each sample mean to 0
                samplewise_center=self.samplewise_center,
                # divide inputs by std of the dataset
                featurewise_std_normalization=self.featurewise_std_normalization,
                # divide each input by its std
                samplewise_std_normalization=self.samplewise_std_normalization,
                # apply ZCA whitening
                zca_whitening=self.zca_whitening,
                zca_epsilon=1e-6,
                # random rotation range in degrees (0 to 180)
                rotation_range=self.rotation_range,
                # random horizontal / vertical shifts (fraction of width / height)
                width_shift_range=self.width_shift_range,
                height_shift_range=self.height_shift_range,
                # random flips
                horizontal_flip=self.horizontal_flip,
                vertical_flip=self.vertical_flip,
            )
Example #5
0
 def __init__(self, init1, init2, height, channels, classes, width):
     """Store the network geometry and read the architecture options from the CLI config.

     init1, init2 -- stored unchanged as ``init_w`` and ``init_b``.
     height, width, channels, classes -- stored unchanged under the same names.
     """
     opts = config.parse_arguments()
     # values supplied by the caller
     self.height, self.width = height, width
     self.channels = channels
     self.classes = classes
     self.init_w, self.init_b = init1, init2
     # architecture options from the config
     self.main_activation = opts.main_activation
     self.features = opts.features
     self.depth = opts.depth
     self.padding = opts.padding
     self.batchnorm = opts.batchnorm
     self.dropout = opts.dropout
     self.max_norm_const = opts.max_norm_const
     self.max_norm_value = opts.max_norm_value
     # NOTE(review): im_length comes from the config's height, not from the
     # ``height`` argument above -- confirm this is intended.
     self.im_length = opts.height
     self.temperature = 1.0
Example #6
0
    def __init__(self, rmn, mmn, data_return='off', save_mode='off'):
        """Set up model names, output paths, file extensions and the class count.

        rmn         -- ROI model name, stored as ``roi_model_name``.
        mmn         -- main model name, stored as ``main_model_name``.
        data_return -- stored unchanged (default 'off').
        save_mode   -- stored unchanged (default 'off').

        ``self.classes`` is derived from the config's ``label_classes``. An
        unrecognised value leaves it unset and prints a warning.
        """
        args = config.parse_arguments()
        self.roi_model_name = rmn
        self.main_model_name = mmn
        self.data_return = data_return
        self.save_mode = save_mode

        self.STORE_PATH = args.store_data_test
        self.STORE_PATH1 = self.STORE_PATH + '/ROI/train/'
        self.STORE_PATH2 = self.STORE_PATH + '/ROI/test/'
        self.STORE_PATH_main1 = self.STORE_PATH + '/MAIN/train/'
        self.STORE_PATH_main2 = self.STORE_PATH + '/MAIN/test/'
        # NOTE(review): these output directories are not created here --
        # confirm they exist before anything is written to them.

        self.data_extention = args.data_extention
        self.counter_extention = args.counter_extention
        self.restore_from_jpg_tif = args.restore_image

        self.label_classes = args.label_classes
        # value of self.classes for each accepted label_classes setting
        classes_per_label = {
            'both': 2,
            'three': 3,
            'four': 4,
            'first': 1,
            'second': 1,
            'third': 1,
            'fourth': 1,
        }
        if self.label_classes in classes_per_label:
            self.classes = classes_per_label[self.label_classes]
        else:
            # self.classes stays unset; report the bad value instead of failing silently
            print("No acceptable label_classes: {!r}".format(self.label_classes))
Example #7
0
    def __init__(self, X, Y, case):
        """Store the input data and the data-pipeline settings from the CLI config.

        X    -- input data, stored as ``self.X``. When the config's main model
                is 'rgu_net' and ``case`` is 'main', it is replaced by the
                output of ``convert_train_data``.
        Y    -- stored unchanged as ``self.Y``.
        case -- pipeline case, stored as ``self.case``.
        """
        args = config.parse_arguments()
        self.case = case
        self.gan_train_directory = args.gan_train_directory
        self.data_augm = args.data_augm
        self.batch_size = args.batch_size
        self.batch_size_test = args.batch_size_test
        self.num_cores = args.num_cores
        self.validation_split = args.validation_split
        self.validation = args.validation
        self.shuffle = args.shuffle
        self.normalize_image = args.normalize
        self.gancheckpoint = 'checkpoint'
        self.gan_synthetic = args.gan_synthetic
        self.num_synthetic_images = args.num_synthetic_images
        self.X = X
        self.Y = Y
        self.fourier = args.fft_convert_data
        self.STORE_PATH = args.store_data_test
        self.STORE_PATH10 = self.STORE_PATH + '/ROI/train/'
        self.STORE_PATH1 = self.STORE_PATH + '/FFT/train/'
        self.data_extention = "jpeg"
        self.main_model = args.main_model

        if self.main_model == 'rgu_net' and self.case == 'main':
            A, nodes_coordinates = grid_graph(args.height)
            coarsening_levels = args.depth
            # only the last value returned by coarsen_mnist (perm) is used here
            _, _, _, _, perm = coarsen_mnist(
                A, coarsening_levels, nodes_coordinates)
            self.X = convert_train_data(X, perm, args.height)
            print(
                'the modified shape of X input because of RGMM structure is: ')
            print(self.X.shape)

        # The 0.8 alternative was gated on a local flag hard-coded to 'False',
        # so 0.05 was the only value ever used; the dead branch is removed.
        self.threshold_bin = 0.05
Example #8
0
import sys
import cv2
import matplotlib.pyplot as plt
from auto_segm import run_model, config
from keras.losses import mean_squared_error
import numpy as np
import itertools
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
from keras.wrappers.scikit_learn import KerasClassifier


class vision_outputs:
    def __init__(self, case, history=None):
        """Keep the training history and the output/metric settings from the CLI config.

        case    -- stored as ``self.case``.
        history -- stored as ``self.h``; ``plot_loss`` reads ``self.h.history``.
        """
        opts = config.parse_arguments()
        self.case = case
        self.h = history
        # output locations
        self.STORE_TXT = opts.store_txt
        self.store_results = opts.store_results
        # evaluation settings
        self.validation = opts.validation
        self.metrics, self.metrics1, self.metrics2 = (
            opts.metrics, opts.metrics1, opts.metrics2)

    def plot_loss(self):
        '''
			plot the loss results 
		'''
        print(self.h.history.keys())
        if (self.validation == "on"):
Example #9
0
    def __init__(self, path_case):
        """Read the training hyper-parameters from the CLI config and build the optimizer.

        path_case -- 'roi' or 'main'. It selects which optimizer arguments are
                     passed to ``self.pass_optimizer`` and which epoch count
                     becomes ``self.epochs``. For any other value,
                     ``self.optimizer`` keeps ``args.pre_optimizer`` and
                     ``self.epochs`` is not set.

        A ``decay`` argument equal to 666 is a sentinel: it sets
        ``self.exponential_decay`` to 'True', and the decay handed to the
        'main' optimizer becomes 0.
        """
        args = config.parse_arguments()
        self.path_case = path_case

        # model / hardware
        self.model_name = args.main_model
        self.ram = args.ram
        self.ngpu = args.ngpu
        self.num_cores = args.num_cores
        self.cross_validation_number = args.crossval_cycle

        # metrics and loss weighting
        self.metrics = args.metrics
        self.metrics1 = args.metrics1
        self.metrics2 = args.metrics2
        self.weights = args.loss_weights

        # batching, epochs and validation
        self.batch_size = args.batch_size
        self.batch_size_test = args.batch_size_test
        self.epochs_roi = args.epochs_roi
        self.epochs_main = args.epochs_main
        self.validation_split = args.validation_split
        self.validation = args.validation
        self.shuffle = args.shuffle

        # data handling
        self.normalize_image = args.normalize
        self.roi_shape = args.roi_shape_roi
        self.mode = args.mode_convert
        self.label_classes = args.label_classes
        self.fourier = args.fft_convert_data

        # storage and callbacks
        self.store_model = args.store_txt
        self.store_model_path = args.store_txt
        self.weight_name = 'first'
        self.early_stopping = args.early_stop
        self.monitor = args.monitor_callbacks

        # defaults; self.optimizer is replaced below for 'roi' / 'main'
        self.optimizer = args.pre_optimizer
        self.lrate = args.learning_rate

        # The sentinel must be resolved before the 'main' argument dict is
        # built, so that dict sees decay == 0.
        sentinel_decay = args.decay == 666
        self.exponential_decay = 'True' if sentinel_decay else 'False'
        if sentinel_decay:
            args.decay = 0

        roi_candidates = {
            'lr': args.roi_learning_rate,
            'momentum': args.roi_momentum,
            'decay': args.roi_decay,
            'seed': args.roi_seed,
        }
        main_candidates = {
            'lr': args.learning_rate,
            'momentum': args.momentum,
            'decay': args.decay,
            'seed': args.seed,
        }

        # entries left as None are dropped before the optimizer is built
        if self.path_case == 'roi':
            roi_kwargs = {name: value for name, value in roi_candidates.items()
                          if value is not None}
            self.optimizer = self.pass_optimizer(args.roi_optimizer, roi_kwargs)
            self.epochs = self.epochs_roi
        elif self.path_case == 'main':
            main_kwargs = {name: value for name, value in main_candidates.items()
                           if value is not None}
            self.optimizer = self.pass_optimizer(args.m_optimizer, main_kwargs)
            self.epochs = self.epochs_main
Example #10
0
	def __init__(self, analysis, path_case):
		"""
		Initialize the data-loading configuration from the CLI config.

		analysis  -- 'train' or 'test' is used directly as the data set name;
		             any other value selects 'train_prediction' / 'train'.
		path_case -- 'roi', 'main' or 'pre'; selects which group of shape,
		             path and extension arguments is read. Any other value
		             leaves those attributes unset, and building
		             ``image_part`` then fails.
		"""
		args = config.parse_arguments()
		self.path_case = path_case
		self.restore_from_jpg_tif = args.restore_image
		self.rotated = "False"

		if self.path_case in ('roi', 'main', 'pre'):
			# settings shared by every path case; the restored 'main' case
			# overrides some of them below
			self.data_path = args.datapath
			self.data_path2 = args.datapath
			self.STORE_TXT = args.store_txt
			self.counter_path = '/contour/'
			self.PATH_IMAGES = '/image'
			self.PATH_IMAGES2 = '/image'

		if self.path_case == 'roi':
			self.image_shape = args.image_shape_roi
			self.original_image_shape = args.original_image_shape_roi
			self.roi_shape = args.roi_shape_roi
			self.data_extention = args.data_extention_roi
			self.counter_extention = args.counter_extention_roi
		elif self.path_case == 'main':
			self.image_shape = args.image_shape
			self.original_image_shape = args.original_image_shape
			self.roi_shape = args.roi_shape
			if self.restore_from_jpg_tif == 'off':
				self.data_extention = args.data_extention_roi
				self.counter_extention = args.counter_extention_roi
			else:
				# images are read from the ROI folders under store_data_test
				self.data_path = args.store_data_test
				self.counter_path = '/contour_main/'
				self.data_extention = args.data_extention
				self.counter_extention = args.counter_extention
				self.PATH_IMAGES = '/ROI/train'
				self.PATH_IMAGES2 = '/ROI/test'
		elif self.path_case == 'pre':
			self.image_shape = args.image_shape_pre
			self.original_image_shape = args.original_image_shape_pre
			self.roi_shape = args.roi_shape_pre
			self.data_extention = args.data_extention_pre
			self.counter_extention = args.counter_extention_pre
			self.pretrain_window = args.pretrain_window

		# separate the ROI training set from the u_net training set
		if analysis in ('train', 'test'):
			self.n_set_pre = self.n_set = analysis
		else:
			self.n_set_pre = 'train_prediction'
			self.n_set = 'train'

		self.patient_list = args.patient_list
		self.store_contour = args.store_data_test
		side = self.original_image_shape
		self.image_part = np.zeros([side, side])
		self.shuffle = args.shuffle
		self.num_preprocess_threads = args.num_cores
		self.batch_size = args.batch_size
		self.patient_store_style = args.patient_store_style
		self.STORE_PATH = args.store_data_test
		self.label_case1 = args.label_case_extension_1
		self.label_case2 = args.label_case_extension_2

		self.label_classes = args.label_classes
		if self.label_classes in ('first', 'second', 'third', 'fourth'):
			self.classes = 1
		elif self.label_classes == 'both':
			self.classes = 2
		elif self.label_classes == 'three':
			self.classes = 3
		elif self.label_classes == 'four':
			self.classes = 4
		else:
			# self.classes is left unset for unrecognised values
			print("No accebtable label_classes!! ")