# Example no. 1
# 0
def main(args):
    """Run the requested training/testing/manipulation stages for one data split.

    Loads the train/val/test split lists for ``args.split_num`` (creating
    4 splits via ``split_data`` first if ``load_data`` raises IndexError),
    derives the network input shape from the first training image, builds
    the model(s), creates the output directories and dispatches to the
    ``train``/``test``/``manip`` stages enabled on ``args``.

    Side effects: sets ``args.output_name``, ``args.time``, ``args.check_dir``,
    ``args.log_dir``, ``args.tf_log_dir`` and ``args.output_dir`` and creates
    those directories on disk.

    Raises:
        AssertionError: if none of ``args.train``, ``args.test`` or
            ``args.manip`` is set.
    """
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip),\
        'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)
    except IndexError:
        # Create the training and test splits if not found
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir,
                                                    args.split_num)

    # Get image properties from first image. Assume they are all the same.
    img_shape = sitk.GetArrayFromImage(
        sitk.ReadImage(join(args.data_root_dir, 'imgs',
                            train_list[0][0]))).shape
    # NOTE(review): uses the two leading array axes; this matches a 2D image
    # (rows, cols[, channels]). For 3D volumes (z, y, x) the sibling variants
    # use axes 1 and 2 instead -- confirm which data this variant targets.
    net_input_shape = (img_shape[0], img_shape[1], args.slices)
    print(net_input_shape)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = args.save_prefix +\
        'split-' + str(args.split_num) + \
        '-batch-' + str(args.batch_size) + \
        '_shuff-' + str(args.shuffle_data) + \
        '_aug-' + str(args.aug_data) + \
        '_loss-' + str(args.loss) + \
        '_strid-' + str(args.stride) + \
        '_lr-' + str(args.initial_lr) + \
        '_recon-' + str(args.recon_wei)

    # assumes `time` is a module-level run timestamp defined elsewhere -- TODO confirm
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    # exist_ok=True tolerates already-existing directories while still
    # surfacing genuine failures (e.g. permission errors).
    for directory in (args.check_dir, args.log_dir,
                      args.tf_log_dir, args.output_dir):
        makedirs(directory, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)

    if args.manip:
        # The assert above accepts manip as a valid mode, so run it instead of
        # silently doing nothing.
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
# Example no. 2
# 0
def main(args):
    """Run the requested stages with either the Keras or the PyTorch model.

    Loads the train/val/test split lists for ``args.split_num`` (creating
    4 splits via ``split_data`` and retrying if the first ``load_data`` call
    fails), derives the network input shape from the first training image,
    builds the model, creates the output directories and dispatches to the
    enabled stages.

    With ``args.pytorch`` set, only training is supported (``CapsNetR3`` via
    ``train_pytorch.train``); requested test/manip stages are skipped with a
    warning. Otherwise the Keras model list from ``create_model`` drives
    ``train``, ``test`` and ``manip``.

    Side effects: sets ``args.output_name``, ``args.time``, ``args.check_dir``,
    ``args.log_dir``, ``args.tf_log_dir`` and ``args.output_dir`` and creates
    those directories on disk.

    Raises:
        AssertionError: if none of ``args.train``, ``args.test`` or
            ``args.manip`` is set.
    """
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found.
        # `except Exception` (not bare) keeps Ctrl-C / SystemExit working.
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
    # Axes 1 and 2 of the array: assumes a 3D volume read as (slices, rows, cols) -- TODO confirm
    net_input_shape = (img_shape[1], img_shape[2], args.slices)

    if args.pytorch:
        from capsnet_pytorch import CapsNetR3
        model = CapsNetR3()
    else:
        # Create the model for training/testing/manipulation
        model_list = create_model(args=args, input_shape=net_input_shape)
        print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    # assumes `time` is a module-level run timestamp defined elsewhere -- TODO confirm
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    # exist_ok=True replaces the bare `except: pass`, which also hid real
    # failures such as permission errors.
    for directory in (args.check_dir, args.log_dir, args.tf_log_dir, args.output_dir):
        makedirs(directory, exist_ok=True)

    if args.pytorch:
        if args.train:
            from train_pytorch import train
            # Run training
            train(args, model, train_list, net_input_shape)
        if args.test or args.manip:
            # No PyTorch test/manip path exists here; say so instead of
            # silently ignoring the request.
            print('Warning: test/manip are not implemented for the PyTorch model; skipping.')
    else:
        if args.train:
            from train import train
            # Run training
            train(args, train_list, val_list, model_list[0], net_input_shape)

        if args.test:
            from test import test
            # Run testing
            test(args, test_list, model_list, net_input_shape)

        if args.manip:
            from manip import manip
            # Run manipulation of segcaps
            manip(args, test_list, model_list, net_input_shape)
# Example no. 3
# 0
def main(args):
    """Run training/testing/manipulation on 2D images resized to 128x128.

    Loads the train/val/test split lists for ``args.split_num`` (creating
    4 splits via ``split_data`` and retrying if the first ``load_data`` call
    fails), reads the first training image with OpenCV to derive the network
    input shape, builds a model with ``class_num=6``, creates the output
    directories and dispatches to the enabled stages.

    Empty rows are removed in place from ``train_list`` / ``test_list`` before
    training / testing. Testing evaluates the checkpoint(s) named
    ``saved_models/<net>/model<i>.hdf5`` for each ``i`` in ``range(39, 40)``.

    Side effects: sets ``args.output_name``, ``args.time``, ``args.check_dir``,
    ``args.log_dir``, ``args.tf_log_dir``, ``args.output_dir`` and (when
    testing) ``args.weights_path``, and creates the directories on disk.

    Raises:
        AssertionError: if none of ``args.train``, ``args.test`` or
            ``args.manip`` is set.
        FileNotFoundError: if the first training image cannot be read.
    """
    # cv2.resize takes dsize as (width, height); square here, so order is moot.
    image_shape = (128, 128)
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found.
        # `except Exception` (not bare) keeps Ctrl-C / SystemExit working.
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    first_img_path = join(args.data_root_dir, 'imgs', train_list[0][0])
    im = cv2.imread(first_img_path)
    if im is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('Could not read image: {}'.format(first_img_path))
    gray = cv2.cvtColor(cv2.resize(im, image_shape), cv2.COLOR_BGR2GRAY)
    net_input_shape = (gray.shape[0], gray.shape[1], args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape, class_num=6)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])
    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    # assumes `time` is a module-level run timestamp defined elsewhere -- TODO confirm
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    # exist_ok=True replaces the bare `except: pass`, which also hid real
    # failures such as permission errors.
    for directory in (args.check_dir, args.log_dir, args.tf_log_dir, args.output_dir):
        makedirs(directory, exist_ok=True)

    if args.train:
        print('training')
        from train import train
        # Drop empty rows in place (single O(n) pass instead of repeated remove()).
        # To choose a GPU, set the CUDA_VISIBLE_DEVICES environment variable
        # before launching; assigning a local variable here has no effect.
        train_list[:] = [row for row in train_list if row != []]
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Drop empty rows in place before testing.
        test_list[:] = [row for row in test_list if row != []]
        # Run testing for each checkpoint index (currently only 39).
        for i in range(39, 40):
            # Relative path; NOTE(review): presumably resolved against the data
            # root inside `test` -- confirm.
            args.weights_path = f'saved_models/{args.net}/model{i}.hdf5'
            test(args, test_list, model_list, net_input_shape, i)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
# Example no. 4
# 0
def main(args):
    """Run the requested training/testing/manipulation stages for one data split.

    Loads the train/val/test split lists for ``args.split_num`` (creating
    4 splits via ``split_data`` and retrying if the first ``load_data`` call
    fails), derives the network input shape from the first training image,
    builds the model(s), creates the output directories and dispatches to the
    ``train``/``test``/``manip`` stages enabled on ``args``.

    Side effects: sets ``args.output_name``, ``args.time``, ``args.check_dir``,
    ``args.log_dir``, ``args.tf_log_dir`` and ``args.output_dir`` and creates
    those directories on disk.

    Raises:
        AssertionError: if none of ``args.train``, ``args.test`` or
            ``args.manip`` is set.
    """
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found.
        # `except Exception` (not bare) keeps Ctrl-C / SystemExit working.
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # Get image properties from first image. Assume they are all the same.
    img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
    # Axes 1 and 2 of the array: assumes a 3D volume read as (slices, rows, cols) -- TODO confirm
    net_input_shape = (img_shape[1], img_shape[2], args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    # assumes `time` is a module-level run timestamp defined elsewhere -- TODO confirm
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    # exist_ok=True replaces the bare `except: pass`, which also hid real
    # failures such as permission errors.
    for directory in (args.check_dir, args.log_dir, args.tf_log_dir, args.output_dir):
        makedirs(directory, exist_ok=True)

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)
def main(args):
    """Run training, testing and/or prediction for one experiment split.

    Builds ``args.output_name`` from the experiment settings and creates the
    checkpoint, log, TF-log and plot directories. It then loads the labelled
    train/val/test lists (when training or testing) and the unlabelled
    prediction list from ``split_lists/<exp_name>/pred_split_<n>.csv`` (when
    predicting). If ``args.create_all_imgs`` is set, all images are
    pre-converted up front. Finally it derives the network input shape from
    the first image, builds the model and dispatches to the enabled stages.

    Side effects: sets ``args.output_name``, ``args.time``, ``args.check_dir``,
    ``args.log_dir``, ``args.tf_log_dir``, ``args.output_dir`` and
    ``args.resize_shape`` and creates the directories on disk.

    Raises:
        AssertionError: if none of ``args.train``, ``args.test`` or
            ``args.pred`` is set.
    """
    # Ensure training, testing, and prediction are not all turned off
    assert (args.train or args.test or args.pred), \
        'Cannot have train, test, and pred all set to 0, Nothing to do.'

    # Set the output name to the experiment name, testing split, and all relevant hyperparameters for easy reference.
    args.output_name = 'exp-{}_split-{}_batch-{}_shuff-{}_aug-{}_loss-{}_slic-{}_sub-{}_strid-{}_lr-{}_' \
                       'dpre-{}_fbase-{}_dis-{}'.format(
                        args.exp_name, args.split_num, args.batch_size, args.shuffle_data, args.aug_data, args.loss,
                        args.slices, args.subsamp, args.stride, args.initial_lr, args.use_default_pretrained,
                        args.freeze_base_weights, args.disentangled)
    # `time` is defined elsewhere; it is used as a path component below, so it must be a string.
    args.time = time

    # Set the name for the fusion strategy
    fusion_type = 'disen' if args.disentangled else 'early'
    run_name = args.net + '_' + fusion_type

    # Checkpoints, csv logs, TF logs and train/test plots each get their own directory.
    args.check_dir = os.path.join(args.data_root_dir, 'saved_models', args.exp_name, run_name)
    args.log_dir = os.path.join(args.data_root_dir, 'logs', args.exp_name, run_name)
    args.tf_log_dir = os.path.join(args.log_dir, 'tf_logs')
    args.output_dir = os.path.join(args.data_root_dir, 'plots', args.exp_name, run_name, args.time)
    # exist_ok=True replaces the bare `except: pass`, which also hid real
    # failures such as permission errors.
    for directory in (args.check_dir, args.log_dir, args.tf_log_dir, args.output_dir):
        os.makedirs(directory, exist_ok=True)

    # Load images for this split
    all_imgs_list = []
    if args.train or args.test:
        train_list, val_list, test_list = load_data(root=args.data_root_dir, mod_dirs=args.modality_dir_list,
                                                    exp_name=args.exp_name, split=args.split_num,
                                                    k_folds=args.k_fold_splits, val_split=args.val_split,
                                                    rand_seed=args.rand_seed)

        # Print the images selected for validation
        all_imgs_list = all_imgs_list + list(train_list) + list(val_list) + list(test_list)
        print('\nFound a total of {} images with labels.'.format(len(all_imgs_list)))
        print('\t{} images for training.'.format(len(train_list)))
        print('\t{} images for validation.'.format(len(val_list)))
        print('\t{} images for testing.'.format(len(test_list)))
        print('\nRandomly selected validation images:')
        print(val_list)
        print('\n')

    # If the user selected to do predictions (no GT), load those prediction images
    if args.pred:
        pred_csv = os.path.join(args.data_root_dir, 'split_lists', args.exp_name,
                                'pred_split_{}.csv'.format(args.split_num))
        # newline='' is what the csv module expects for files it reads.
        with open(pred_csv, 'r', newline='') as f:
            pred_list = list(csv.reader(f))
        all_imgs_list = all_imgs_list + list(pred_list)

    # This creates all images up front instead of dynamically during training. Beneficial if paying for GPU hours.
    if args.create_all_imgs:
        print('-' * 98, '\nCreating all images... This will take some time.\n', '-' * 98)
        for img_pairs in tqdm(all_imgs_list):
            _, _ = convert_data_to_numpy(root_path=args.data_root_dir, img_names=img_pairs,
                                         mod_dirs=args.modality_dir_list, exp_name=args.exp_name,
                                         no_masks=False, overwrite=True)

    # Set the network input shape depending on using a 2D or 3D network.
    # Channel count: 3 per modality directory after the first (the first is presumably the masks -- confirm).
    num_channels = 3 * (len(args.modality_dir_list.split(',')) - 1)
    if '3d' in args.net or 'inflated' in args.net:
        net_input_shape = [None, None, args.slices, num_channels]
    else:
        net_input_shape = [None, None, num_channels]

    # Create the model for training/testing/manipulation.
    # GetSize() returns a tuple; copy it into a list so the resize overrides
    # below can assign into it (item assignment on the tuple raised TypeError).
    # NOTE(review): SimpleITK's GetSize() is ordered (x, y, ...) i.e. width
    # first, while resize_shape is (height, width) -- confirm the axis order for
    # non-square images.
    train_shape = list(sitk.ReadImage(os.path.join(args.data_root_dir, all_imgs_list[0][0])).GetSize()[:-1])
    args.resize_shape = [args.resize_hei, args.resize_wid]
    if args.resize_shape[0] is not None:
        train_shape[0] = args.resize_shape[0]
    if args.resize_shape[1] is not None:
        train_shape[1] = args.resize_shape[1]

    # Update the network input shape based on our data size.
    # Round each side down to a multiple of 2**6 (assumes 6 downsamples).
    train_shape = (train_shape[0] // (2 ** 6) * (2 ** 6), train_shape[1] // (2 ** 6) * (2 ** 6))
    net_input_shape[0] = train_shape[0]
    net_input_shape[1] = train_shape[1]
    model = create_model(args=args, input_shape=net_input_shape)

    # Print the model summary; older Keras exposes keras.utils.print_summary,
    # newer versions do not, so fall back to model.summary().
    try:
        from keras.utils import print_summary
    except ImportError:
        model.summary()
    else:
        print_summary(model=model, positions=[.38, .69, .8, 1.])

    if args.train:
        # Run training
        print('-'*98,'\nRunning Training\n','-'*98)
        from train_class import train
        train(args, train_list, val_list, model, net_input_shape)

    if args.test:
        # Run testing
        print('-'*98,'\nRunning Testing\n','-'*98)
        from test_class import test
        test(args, test_list, model, net_input_shape)

    if args.pred:
        # Run prediction on new data
        print('-'*98,'\nRunning Prediction\n','-'*98)
        from predict_class import predict
        predict(args, pred_list, model, net_input_shape)
# Example no. 6
# 0
def main(args):
    """Run training/testing/manipulation with a fixed 512x512 input and preloaded weights.

    Loads the train/val/test split lists for ``args.split_num`` (creating
    4 splits via ``split_data`` and retrying if the first ``load_data`` call
    fails), and checks that the first training image is readable. It then
    builds the model(s) with a fixed ``(512, 512, args.slices)`` input, loads
    hard-coded weights into the first model, creates the output directories
    and dispatches to the enabled stages.

    Empty rows are removed in place from ``train_list`` / ``test_list`` before
    training / testing. Testing uses checkpoint index 5.

    Side effects: sets ``args.output_name``, ``args.time``, ``args.check_dir``,
    ``args.log_dir``, ``args.tf_log_dir``, ``args.output_dir`` and (when
    testing) ``args.weights_path``, and creates the directories on disk.

    Raises:
        AssertionError: if none of ``args.train``, ``args.test`` or
            ``args.manip`` is set.
        FileNotFoundError: if the first training image cannot be read.
    """
    # Ensure training, testing, and manip are not all turned off
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # Load the training, validation, and testing data
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except Exception:
        # Create the training and test splits if not found.
        # `except Exception` (not bare) keeps Ctrl-C / SystemExit working.
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # The network input size is fixed below, so the first image's shape is not
    # used; it is still read so that a broken image list fails early.
    first_img_path = join(args.data_root_dir, 'imgs', train_list[0][0])
    if cv2.imread(first_img_path) is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('Could not read image: {}'.format(first_img_path))
    net_input_shape = (512, 512, args.slices)

    # Create the model for training/testing/manipulation
    model_list = create_model(args=args, input_shape=net_input_shape)
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])
    # NOTE(review): hard-coded path, not derived from args.data_root_dir or
    # args.net; presumably resolved against the current working directory -- confirm.
    model_list[0].load_weights('liver_data3/saved_models_original/segcapsr3/model5.hdf5')
    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    # assumes `time` is a module-level run timestamp defined elsewhere -- TODO confirm
    args.time = time

    args.check_dir = join(args.data_root_dir, 'saved_models', args.net)
    args.log_dir = join(args.data_root_dir, 'logs', args.net)
    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    # exist_ok=True replaces the bare `except: pass`, which also hid real
    # failures such as permission errors.
    for directory in (args.check_dir, args.log_dir, args.tf_log_dir, args.output_dir):
        makedirs(directory, exist_ok=True)

    if args.train:
        print('training')
        from train import train
        # Drop empty rows in place (single O(n) pass instead of repeated remove()).
        # To choose a GPU, set the CUDA_VISIBLE_DEVICES environment variable
        # before launching; assigning a local variable here has no effect.
        train_list[:] = [row for row in train_list if row != []]
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Drop empty rows in place before testing.
        test_list[:] = [row for row in test_list if row != []]
        # Run testing on a fixed checkpoint index.
        i = 5
        # NOTE(review): network folder is hard-coded to 'segcapsr3' rather than args.net -- confirm.
        args.weights_path = f'saved_models_original/segcapsr3/model{i}.hdf5'
        test(args, test_list, model_list, net_input_shape, i)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)