class Deepy3d(object):
    """
    CT-image bone segmentation class.

    Calls a configuration file reader, reads in patient CT-scan data and trains
    a convolutional neural network. After training, it can perform bone
    segmentation on new patients and visualise its predictions, its learned
    kernels and the activations of the new scan in each layer.
    """
    def __init__(self, config_file):
        """
        Set up the segmenter from a configuration file.

        Parameters
        ----------
        config_file
            Passed unchanged to ``ConfigReader`` (presumably a path to the
            configuration file). Every other method reads its settings from
            the resulting ``self.config``.
        """
        reader = ConfigReader(config_file)
        self.config = reader

    def process_training_data(self):
        """
        Read each patient's CT scan and label files and save them as slices.

        Handles the first ``config.get_num_patients()`` entries of the
        configured CT-scan, label and thresholded-label file lists. For each
        patient, a ``PatientIO`` is built from the scan file:

        * ``save_scan`` receives the training CT slice PNG and NPY
          directories.
        * ``read_save_labels`` receives the label file, the thresholded file
          and their respective PNG/NPY slice directories.

        Raises
        ------
        ValueError
            If any of the three file lists has fewer entries than the
            configured number of patients. This is checked before any patient
            is processed, so a misconfiguration cannot leave a partially
            written training set behind.
        """
        print('* Reading CT scan files.')

        # Extract file lists (entry i belongs to patient i)
        trn_files = self.config.get_CT_scans()
        lbl_files = self.config.get_CT_labels()
        thr_files = self.config.get_CT_thresholded()
        num_patients = self.config.get_num_patients()

        # Fail fast instead of hitting an IndexError halfway through the loop
        # after earlier patients have already been written to disk.
        shortest = min(len(trn_files), len(lbl_files), len(thr_files))
        if shortest < num_patients:
            raise ValueError(
                'Configuration expects {} patients but provides {} CT scans, '
                '{} label files and {} thresholded files.'.format(
                    num_patients, len(trn_files), len(lbl_files),
                    len(thr_files)))

        # Output directories are identical for every patient; look them up once
        trn_png_dir = self.config.get_trn_CT_slice_PNG_dir()
        trn_npy_dir = self.config.get_trn_CT_slice_NPY_dir()
        lbl_png_dir = self.config.get_trn_label_slice_PNG_dir()
        lbl_npy_dir = self.config.get_trn_label_slice_NPY_dir()
        thr_png_dir = self.config.get_thr_label_slice_PNG_dir()
        thr_npy_dir = self.config.get_thr_label_slice_NPY_dir()

        # range() is the shortest iterable after the check above, so zip
        # visits exactly the first num_patients entries of each list.
        for i, trn_file, lbl_file, thr_file in zip(range(num_patients),
                                                   trn_files,
                                                   lbl_files,
                                                   thr_files):
            patient = PatientIO(trn_file)

            # Save the CT scan slices (.png and .npy)
            patient.save_scan(i, trn_png_dir, trn_npy_dir)

            # Save manual and thresholded label slices
            patient.read_save_labels(i, lbl_file,
                                     lbl_png_dir,
                                     lbl_npy_dir,
                                     thr_file,
                                     thr_png_dir,
                                     thr_npy_dir)

    def acquire_patches(self, balanced=True):
        """
        Sample training patches from the stored NumPy slice files.

        The ``*.npy`` files in the training CT, training label and
        thresholded-label slice directories are listed in sorted order and
        wrapped in a ``Scans`` object. The result of its ``sample_patches``
        call (configured classes, patch size, patch count and edges) is
        returned unchanged.

        Parameters
        ----------
        balanced : bool
            Forwarded as-is to ``Scans.sample_patches``.
        """
        print('* Extracting patches from CT-scan slices.')
        cfg = self.config

        # The glob pattern is plain concatenation ``directory + '*.npy'``, so
        # a directory without a trailing separator would match files beside
        # it rather than inside it.
        slice_dirs = (cfg.get_trn_CT_slice_NPY_dir(),
                      cfg.get_trn_label_slice_NPY_dir(),
                      cfg.get_thr_label_slice_NPY_dir())
        ct_paths, label_paths, threshold_paths = [
            sorted(glob(directory + '*.npy')) for directory in slice_dirs
        ]

        sampler = Scans(ct_paths, label_paths, threshold_paths)
        return sampler.sample_patches(
            classes=cfg.get_classes(),
            patch_size=cfg.get_patch_size(),
            num_patches=cfg.get_num_patches(),
            edges=cfg.get_edges(),
            balanced=balanced,
        )

    def initialise_network(self):
        """
        Build the optimiser and the CNN model from the configuration.

        The optimiser is selected by ``config.get_optimizer()``:

        * ``'SGD'`` uses the learning rate, decay, momentum and Nesterov flag.
        * ``'RMSprop'`` uses the learning rate, rho, epsilon and decay.

        The resulting ``CNN_Model`` is stored on ``self.model`` and its
        ``compile_single`` method is called.

        Raises
        ------
        ValueError
            If the configured optimiser is neither ``'SGD'`` nor
            ``'RMSprop'``.
        """
        print('* Initializing network.')
        cfg = self.config
        optimizer_name = cfg.get_optimizer()

        # Reject unknown optimisers before touching Keras at all
        if optimizer_name not in ('SGD', 'RMSprop'):
            raise ValueError('Optimizer type not supported.')

        if optimizer_name == 'SGD':
            opt = ks.optimizers.SGD(
                lr=cfg.get_learning_rate(),
                decay=cfg.get_decay(),
                momentum=cfg.get_momentum(),
                nesterov=cfg.get_Nesterov(),
            )
        else:
            opt = ks.optimizers.RMSprop(
                lr=cfg.get_learning_rate(),
                rho=cfg.get_rho(),
                epsilon=cfg.get_epsilon(),
                decay=cfg.get_decay(),
            )

        model_settings = dict(
            patch_size=cfg.get_patch_size(),
            num_epochs=cfg.get_num_epochs(),
            num_classes=len(cfg.get_classes()),
            batch_size=cfg.get_batch_size(),
            # NOTE(review): only the first entry of get_weight_reg() is
            # used here — confirm that is intended.
            weight_reg=cfg.get_weight_reg()[0],
            dropout=cfg.get_dropout(),
            num_filters=cfg.get_num_filters(),
            kernel_dims=cfg.get_kernel_dims(),
            activation=cfg.get_activation(),
        )
        self.model = CNN_Model(optimizer=opt, **model_settings)
        self.model.compile_single()

    def cross_evaluation(self,
                         X,
                         Y,
                         Yt,
                         patient_index,
                         num_folds=None,
                         predict_im=False):
        """
        Train and evaluate the model on a patient hold-out basis.

        Call the CNN_Model's cross-validation method, then print the
        classification report and accuracy for the CNN and, for comparison,
        for the thresholding labels ``Yt``.

        Parameters
        ----------
        X
            Input samples, passed straight to ``self.model.cross_validate``.
        Y
            One-hot ground-truth labels; decoded with ``argmax`` over axis 1.
        Yt
            One-hot thresholding labels with the same layout as ``Y``.
        patient_index
            Passed to ``cross_validate``; presumably maps each sample to its
            patient so that folds hold out whole patients (verify there).
        num_folds : int, optional
            Number of cross-validation folds. If ``None`` (default), the
            value from ``config.get_num_folds()`` is used.
        predict_im : bool
            Currently unused; kept for backward compatibility.
        """
        # An explicitly passed num_folds used to be silently ignored in
        # favour of the config value; honour it now and only fall back to
        # the configuration when the caller did not specify one.
        if num_folds is None:
            num_folds = self.config.get_num_folds()

        # Start cross-validation procedure
        acc, preds = self.model.cross_validate(
            X, Y, patient_index, num_folds=num_folds)

        # Map label one-hot encoding to label vector
        y = np.argmax(Y, axis=1)
        yt = np.argmax(Yt, axis=1)

        # Report performance of the model
        print('* Classification report CNN:')
        print(classification_report(y, preds))
        print('Accuracy model = ', acc)

        # Report performance of thresholding for comparison
        print('* Classification report for thresholding:')
        print(classification_report(y, yt))
        print('Accuracy thresholding = ', np.mean(y == yt, axis=0))