Example #1
    def train(self,
              training_data_collection,
              validation_data_collection=None,
              output_model_filepath=None,
              input_groups=None,
              training_batch_size=32,
              validation_batch_size=32,
              training_steps_per_epoch=None,
              validation_steps_per_epoch=None,
              initial_learning_rate=.0001,
              learning_rate_drop=None,
              learning_rate_epochs=None,
              num_epochs=None,
              callbacks=['save_model', 'log'],
              **kwargs):
        """
        input_groups : list of strings, optional
            Specifies which named data groups (e.g. "ground_truth") enter which input
            data slot in your model.
        """

        self.create_data_generators(training_data_collection,
                                    validation_data_collection, input_groups,
                                    training_batch_size, validation_batch_size,
                                    training_steps_per_epoch,
                                    validation_steps_per_epoch)

        self.callbacks = get_callbacks(
            callbacks,
            output_model_filepath=output_model_filepath,
            data_collection=training_data_collection,
            model=self,
            batch_size=training_batch_size,
            backend='keras',
            **kwargs)

        try:
            if validation_data_collection is None:
                self.model.fit_generator(
                    generator=self.training_data_generator,
                    steps_per_epoch=self.training_steps_per_epoch,
                    epochs=num_epochs,
                    callbacks=self.callbacks)
            else:
                self.model.fit_generator(
                    generator=self.training_data_generator,
                    steps_per_epoch=self.training_steps_per_epoch,
                    epochs=num_epochs,
                    validation_data=self.validation_data_generator,
                    validation_steps=self.validation_steps_per_epoch,
                    callbacks=self.callbacks,
                    workers=0)
        except KeyboardInterrupt:
            # On a manual interrupt, still give the callbacks (e.g. model
            # saving and logging) a chance to finalize before exiting.
            for callback in self.callbacks:
                callback.on_train_end()
        except:
            # Any other exception is re-raised unchanged.
            raise

        return
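
For reference, here is a minimal, self-contained sketch of the pattern this method wraps: fitting a Keras model from a generator with callbacks attached, and letting a manual interrupt still finalize those callbacks. It uses plain tf.keras (Model.fit, which accepts generators directly and supersedes fit_generator) rather than the library's get_callbacks and create_data_generators helpers; every name in the sketch is an illustrative stand-in.

import numpy as np
import tensorflow as tf

def batch_generator(batch_size=32):
    # Stand-in for the library's training_data_generator: yields
    # (inputs, targets) tuples indefinitely.
    while True:
        x = np.random.rand(batch_size, 16).astype('float32')
        y = np.random.randint(0, 2, size=(batch_size, 1)).astype('float32')
        yield x, y

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(16,)),
    tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='binary_crossentropy')

# Rough analogues of the 'save_model' and 'log' callbacks requested above.
callbacks = [tf.keras.callbacks.ModelCheckpoint('model.h5'),
             tf.keras.callbacks.CSVLogger('training_log.csv')]

try:
    model.fit(batch_generator(),
              steps_per_epoch=10,
              epochs=5,
              callbacks=callbacks)
except KeyboardInterrupt:
    # Mirror the pattern above: let callbacks clean up on a manual interrupt.
    for callback in callbacks:
        callback.on_train_end()
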
Example #2
    def fit_one_batch(self,
                      training_data_collection,
                      output_model_filepath=None,
                      input_groups=None,
                      output_directory=None,
                      callbacks=['save_model', 'log'],
                      training_batch_size=16,
                      training_steps_per_epoch=None,
                      num_epochs=None,
                      show_results=False,
                      **kwargs):

        # Build a generator that yields the same single batch indefinitely,
        # so the model is trained repeatedly on one fixed batch.
        one_batch_generator = self.keras_generator(
            training_data_collection.data_generator(
                perpetual=True,
                data_group_labels=input_groups,
                verbose=False,
                just_one_batch=True,
                batch_size=training_batch_size))

        self.callbacks = get_callbacks(
            callbacks,
            output_model_filepath=output_model_filepath,
            data_collection=training_data_collection,
            model=self,
            batch_size=training_batch_size,
            backend='keras',
            **kwargs)

        # Default to roughly one full pass over the collection per epoch.
        if training_steps_per_epoch is None:
            training_steps_per_epoch = training_data_collection.total_cases // training_batch_size + 1

        try:
            self.model.fit_generator(generator=one_batch_generator,
                                     steps_per_epoch=training_steps_per_epoch,
                                     epochs=num_epochs,
                                     callbacks=self.callbacks)
        except KeyboardInterrupt:
            # On a manual interrupt, still let the callbacks finalize.
            for callback in self.callbacks:
                callback.on_train_end()
        except:
            # Any other exception is re-raised unchanged.
            raise

        # Pull the fixed batch back out of the generator and run inference on it.
        one_batch = next(one_batch_generator)
        prediction = self.predict(one_batch[0])

        if show_results:
            check_data(
                output_data={
                    self.input_data: one_batch[0],
                    self.targets: one_batch[1],
                    'prediction': prediction
                },
                batch_size=training_batch_size)

        return
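
The method above is essentially a single-batch overfitting check: it trains repeatedly on one fixed batch and then predicts on that same batch, so a correctly wired model and data pipeline should drive the loss toward zero. Below is a small, self-contained illustration of the same idea in plain tf.keras; the generator, model, and shapes are stand-ins, not the library's own objects.

import numpy as np
import tensorflow as tf

# One fixed batch, analogous to the just_one_batch=True generator above.
x_batch = np.random.rand(16, 16).astype('float32')
y_batch = np.random.randint(0, 2, size=(16, 1)).astype('float32')

def one_batch_generator():
    while True:
        yield x_batch, y_batch

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation='relu', input_shape=(16,)),
    tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer='adam', loss='binary_crossentropy')

# Train repeatedly on the same batch; the loss should approach zero if the
# model and pipeline are wired correctly.
model.fit(one_batch_generator(), steps_per_epoch=5, epochs=20, verbose=0)

# Analogue of self.predict(one_batch[0]) followed by check_data(...).
prediction = model.predict(x_batch)
print(np.round(prediction[:4, 0], 3), y_batch[:4, 0])
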
Example #3
    def init_training(self, training_data_collection, kwargs):

        # Outputs
        add_parameter(self, kwargs, 'output_model_filepath')

        # Training Parameters
        add_parameter(self, kwargs, 'num_epochs', 100)
        add_parameter(self, kwargs, 'training_steps_per_epoch', 10)
        add_parameter(self, kwargs, 'training_batch_size', 16)
        add_parameter(self, kwargs, 'callbacks')

        self.callbacks = get_callbacks(backend='tensorflow',
                                       model=self,
                                       batch_size=self.training_batch_size,
                                       **kwargs)

        # Initialize the session and build the TensorFlow graph before
        # wiring up the data generators.
        self.init_sess()
        self.build_tensorflow_model(self.training_batch_size)
        self.create_data_generators(
            training_data_collection,
            training_batch_size=self.training_batch_size,
            training_steps_per_epoch=self.training_steps_per_epoch)

        return
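
This method pulls its training settings out of a kwargs dictionary via add_parameter, which, judging only from how it is called here, copies a named keyword argument onto the instance with an optional default. The sketch below is an assumed, minimal reimplementation of that pattern for illustration; it is not the library's actual helper.

# NOTE: this helper is an assumption about what the library's add_parameter
# does, based only on how it is called above: copy a named keyword argument
# onto the instance, falling back to a default when the key is absent.
def add_parameter(instance, kwargs, name, default=None):
    setattr(instance, name, kwargs.get(name, default))

class TinyTrainer(object):
    def init_training(self, kwargs):
        add_parameter(self, kwargs, 'num_epochs', 100)
        add_parameter(self, kwargs, 'training_steps_per_epoch', 10)
        add_parameter(self, kwargs, 'training_batch_size', 16)

trainer = TinyTrainer()
trainer.init_training({'num_epochs': 25, 'training_batch_size': 8})
print(trainer.num_epochs, trainer.training_steps_per_epoch, trainer.training_batch_size)
# prints: 25 10 8
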
Example #4
    def train(self,
              training_data_collection,
              validation_data_collection=None,
              output_model_filepath=None,
              input_groups=None,
              training_batch_size=32,
              validation_batch_size=32,
              training_steps_per_epoch=None,
              validation_steps_per_epoch=None,
              initial_learning_rate=.0001,
              learning_rate_drop=None,
              learning_rate_epochs=None,
              num_epochs=None,
              callbacks=['save_model', 'log'],
              **kwargs):
        """
        input_groups : list of strings, optional
            Specifies which named data groups (e.g. "ground_truth") enter which input
            data slot in your model.
        """

        # Todo: investigate callbacks more thoroughly.
        # Also, maybe something more general for the difference between training and validation.
        # Todo: list-checking for callbacks

        self.create_data_generators(training_data_collection,
                                    validation_data_collection, input_groups,
                                    training_batch_size, validation_batch_size,
                                    training_steps_per_epoch,
                                    validation_steps_per_epoch)

        if validation_data_collection is None:
            self.model.fit_generator(
                generator=self.training_data_generator,
                steps_per_epoch=self.training_steps_per_epoch,
                epochs=num_epochs,
                pickle_safe=True,
                callbacks=get_callbacks(
                    callbacks=callbacks,
                    output_model_filepath=output_model_filepath,
                    data_collection=training_data_collection,
                    batch_size=training_batch_size,
                    model=self,
                    backend='keras',
                    **kwargs))

        else:
            self.model.fit_generator(
                generator=self.training_data_generator,
                steps_per_epoch=self.training_steps_per_epoch,
                epochs=num_epochs,
                pickle_safe=True,
                validation_data=self.validation_data_generator,
                validation_steps=self.validation_steps_per_epoch,
                callbacks=get_callbacks(
                    callbacks,
                    output_model_filepath=output_model_filepath,
                    data_collection=training_data_collection,
                    model=self,
                    batch_size=training_batch_size,
                    backend='keras',
                    **kwargs))

        return
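
This variant is close to Example #1 but builds its callbacks inline and passes pickle_safe=True, an argument from older Keras releases that corresponds to what later versions call use_multiprocessing. For reference, below is a minimal, self-contained sketch of the validation branch using current tf.keras, where Model.fit accepts generators directly and supersedes fit_generator; all objects in it are illustrative stand-ins.

import numpy as np
import tensorflow as tf

def random_generator(batch_size=32):
    # Stand-in for the library's training and validation data generators.
    while True:
        x = np.random.rand(batch_size, 16).astype('float32')
        y = np.random.randint(0, 2, size=(batch_size, 1)).astype('float32')
        yield x, y

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(16,)),
    tf.keras.layers.Dense(1, activation='sigmoid')])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='binary_crossentropy')

model.fit(random_generator(),
          steps_per_epoch=10,
          epochs=3,
          validation_data=random_generator(),
          validation_steps=5,
          callbacks=[tf.keras.callbacks.ModelCheckpoint('model.h5')])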