Ejemplo n.º 1
0
    def fit_one_batch(self,
                      training_data_collection,
                      output_model_filepath=None,
                      input_groups=None,
                      output_directory=None,
                      callbacks=None,
                      training_batch_size=16,
                      training_steps_per_epoch=None,
                      num_epochs=None,
                      show_results=False,
                      **kwargs):
        """Train the Keras model repeatedly on one batch from a data collection.

        A generator is built from ``training_data_collection.data_generator``
        with ``just_one_batch=True`` and ``perpetual=True``, the model is fit on
        it via ``fit_generator``, and afterwards one more batch is drawn and
        predicted on (optionally displayed with ``check_data``).

        Args:
            training_data_collection: Data source; provides ``data_generator``
                and ``total_cases``.
            output_model_filepath: Forwarded to ``get_callbacks``.
            input_groups: Data group labels forwarded to the data generator.
            output_directory: Currently unused.
            callbacks: Callback specification forwarded to ``get_callbacks``.
                ``None`` (the default) selects ``['save_model', 'log']``; a
                fresh list is built per call so no state is shared.
            training_batch_size: Batch size for the generator and callbacks.
            training_steps_per_epoch: Steps per epoch. Defaults to
                ``total_cases // training_batch_size + 1``.
            num_epochs: Number of epochs. ``None`` is treated as 1.
            show_results: If true, show the batch input, target and prediction
                through ``check_data``.
            **kwargs: Forwarded to ``get_callbacks``.

        Returns:
            None. ``self.callbacks`` is replaced with the constructed callbacks.
        """
        if callbacks is None:
            # Built per call: a list literal default would be shared across
            # calls and could be mutated by get_callbacks.
            callbacks = ['save_model', 'log']

        if num_epochs is None:
            # NOTE(review): Keras expects an integer epoch count (its own
            # default is 1); forwarding None is assumed to fail there.
            num_epochs = 1

        one_batch_generator = self.keras_generator(
            training_data_collection.data_generator(
                perpetual=True,
                data_group_labels=input_groups,
                verbose=False,
                just_one_batch=True,
                batch_size=training_batch_size))

        self.callbacks = get_callbacks(
            callbacks,
            output_model_filepath=output_model_filepath,
            data_collection=training_data_collection,
            model=self,
            batch_size=training_batch_size,
            backend='keras',
            **kwargs)

        if training_steps_per_epoch is None:
            # NOTE(review): this overshoots by one step when total_cases is an
            # exact multiple of the batch size; presumably harmless because the
            # generator is perpetual over a single batch.
            training_steps_per_epoch = training_data_collection.total_cases // training_batch_size + 1

        try:
            self.model.fit_generator(generator=one_batch_generator,
                                     steps_per_epoch=training_steps_per_epoch,
                                     epochs=num_epochs,
                                     callbacks=self.callbacks)
        except KeyboardInterrupt:
            # Let callbacks finalize (e.g. flush logs) when training is
            # interrupted by the user; execution then continues below.
            for callback in self.callbacks:
                callback.on_train_end()

        one_batch = next(one_batch_generator)
        prediction = self.predict(one_batch[0])

        if show_results:
            check_data(output_data={
                self.input_data: one_batch[0],
                self.targets: one_batch[1],
                'prediction': prediction
            },
                       batch_size=training_batch_size)

        return
Ejemplo n.º 2
0
    def on_epoch_end(self, data, logs={}):
        """Render a prediction next to reference data and record the image.

        NOTE(review): ``data`` is not a plain epoch number here. It is indexed
        as ``data[0]`` (epoch index, used in the output filename) and
        ``data[1]`` (reference data shown as ``'real_data'``). The original
        author marked this as a hack to revisit.

        The prediction comes from ``self.deepneuro_model.predict`` with
        ``self.sample_latent`` when no ``epoch_prediction_object`` is set;
        otherwise the prediction object processes the model-input entry of
        ``self.predict_data``. Both arrays are passed to ``check_data`` with an
        ``epoch_<epoch>.png`` path under ``self.depth_dir``. The returned
        prediction image (cast to uint8) is added to the last entry of
        ``self.predictions``.
        """
        current_epoch = data[0]
        real_images = data[1]

        model = self.deepneuro_model
        if self.epoch_prediction_object is not None:
            model_input = self.predict_data[model.input_data]
            prediction = self.epoch_prediction_object.process_case(
                model_input, model=model)
        else:
            prediction = model.predict(sample_latent=self.sample_latent)

        figure_path = os.path.join(self.depth_dir,
                                   'epoch_{}.png'.format(current_epoch))
        panels = {'prediction': prediction, 'real_data': real_images}
        _, images = check_data(panels,
                               output_filepath=figure_path,
                               show_output=False,
                               batch_size=self.epoch_prediction_batch_size)

        newest_frame = images['prediction'].astype('uint8')
        self.predictions[-1] += [newest_frame]
        return
Ejemplo n.º 3
0
    def on_epoch_end(self, epoch, logs={}):
        """Record a rendered prediction snapshot at the end of an epoch.

        Without an ``epoch_prediction_object`` the model predicts directly on
        the model-input entry of ``self.predict_data``; otherwise the
        prediction object processes the whole ``self.predict_data``. The
        result is passed to ``check_data`` with an ``epoch_<epoch>.png`` path
        under ``self.epoch_prediction_dir``.

        If ``check_data`` returns a single image (key ``'prediction'``), that
        image is appended to ``self.predictions``. If it returns several
        (keys ``'prediction_0'``, ``'prediction_1'``, ...), one list holding
        all of them is appended. Images are cast to uint8.
        """
        model = self.deepneuro_model
        predictor = self.epoch_prediction_object
        if predictor is not None:
            prediction = predictor.process_case(self.predict_data, model=model)
        else:
            prediction = model.predict(self.predict_data[model.input_data])

        snapshot_path = os.path.join(self.epoch_prediction_dir,
                                     'epoch_{}.png'.format(epoch))
        _, images = check_data({'prediction': prediction},
                               output_filepath=snapshot_path,
                               show_output=self.show_callback_output,
                               batch_size=self.epoch_prediction_batch_size,
                               **self.kwargs)

        image_count = len(images.keys())
        if image_count <= 1:
            self.predictions.append(images['prediction'].astype('uint8'))
        else:
            per_output = [images['prediction_' + str(i)].astype('uint8')
                          for i in range(image_count)]
            self.predictions.append(per_output)
        return
Ejemplo n.º 4
0
    def on_train_end(self, logs=None):
        """Export the prediction snapshots collected during training.

        Nothing happens when ``self.predictions`` is empty. In ``'gif'`` mode,
        the collected images are written as animated GIFs with
        ``imageio.mimsave`` into ``self.epoch_prediction_dir``:

        * If the entries are lists, one GIF per list index is written as
          ``epoch_prediction_<index>.gif``.
        * Otherwise a single ``epoch_prediction.gif`` is written.

        Args:
            logs: Unused; accepted for the callback interface.

        Raises:
            NotImplementedError: If ``epoch_prediction_output_mode`` is
                ``'mosaic'`` (not implemented yet) or any value other than
                ``'gif'``.
        """
        if not self.predictions:
            return

        mode = self.epoch_prediction_output_mode

        if mode == 'gif':
            if isinstance(self.predictions[0], list):
                # Regroup the per-epoch lists so each list index gets its own
                # frame sequence (and its own GIF).
                for output in range(len(self.predictions[0])):
                    current_predictions = [
                        item[output] for item in self.predictions
                    ]
                    imageio.mimsave(
                        os.path.join(
                            self.epoch_prediction_dir,
                            'epoch_prediction_' + str(output) + '.gif'),
                        current_predictions)
            else:
                imageio.mimsave(
                    os.path.join(self.epoch_prediction_dir,
                                 'epoch_prediction.gif'), self.predictions)

        elif mode == 'mosaic':
            # TODO: implement mosaic export of the per-epoch predictions.
            raise NotImplementedError(
                'Training callback mosaics are not yet implemented. '
                '(epoch_prediction_output_mode = \'mosaic\')')

        else:
            raise NotImplementedError(
                'Unsupported epoch_prediction_output_mode: {!r}'.format(mode))
Ejemplo n.º 5
0
    def on_epoch_end(self, epoch, logs={}):
        """Snapshot the model output after an epoch and keep it as a uint8 image.

        Without an ``epoch_prediction_object`` the model is sampled via
        ``predict(sample_latent=self.sample_latent)``. Otherwise the prediction
        object processes the model-input entry of ``self.predict_data``. The
        result is passed to ``check_data`` with an ``epoch_<epoch>.png`` path
        under ``self.epoch_prediction_dir``. The returned ``'prediction'``
        image is then appended to ``self.predictions``.
        """
        model = self.deepneuro_model
        predictor = self.epoch_prediction_object
        if predictor is not None:
            prediction = predictor.process_case(
                self.predict_data[model.input_data], model=model)
        else:
            prediction = model.predict(sample_latent=self.sample_latent)

        png_path = os.path.join(self.epoch_prediction_dir,
                                'epoch_{}.png'.format(epoch))
        _, rendered = check_data({'prediction': prediction},
                                 output_filepath=png_path,
                                 show_output=False,
                                 batch_size=self.epoch_prediction_batch_size)

        self.predictions.append(rendered['prediction'].astype('uint8'))
        return
Ejemplo n.º 6
0
    def save_output(self, postprocessor_idx=None, raw_data=None):
        """Write every object in ``self.return_objects`` to disk.

        For each object, ``self.generate_filenames(raw_data, postprocessor_idx)``
        is called (once per object) to obtain the target filenames. The object
        is written with ``save_to_csv`` when ``self.output_extension`` is
        ``'.csv'``, otherwise with ``save_to_disk`` (which also receives
        ``raw_data``). When ``self.show_output`` is set, each object is also
        shown through ``check_data``.

        TODO: output is currently assumed to be Nifti; detect the format
        automatically or select it via a class variable, and ideally move the
        saving logic into a saving.py helper in utils.
        """
        for output_item in self.return_objects:
            filenames = self.generate_filenames(raw_data, postprocessor_idx)

            if self.output_extension == '.csv':
                self.save_to_csv(output_item, filenames)
            else:
                self.save_to_disk(output_item, filenames, raw_data)

            if not self.show_output:
                continue

            # Display call must be revisited as Outputs grows more data types.
            check_data({'prediction': output_item}, batch_size=1, **self.kwargs)

        return