Exemple #1
0
class BaseModel(Model):
    """
    Base model that keeps a ``Sequential`` network in ``self.model`` together
    with the hyperparameters it was configured with.

    ``predict_name`` is left unimplemented here and must be provided by
    subclasses.
    """

    def __init__(self, hyperparameters: Dict[str, Any]):
        """
        :arg hyperparameters configuration values; stored as-is on the
             instance and written to ``model_config.json`` by ``save``.
        """
        super(BaseModel, self).__init__()
        self.hyperparameters = hyperparameters
        self.model = Sequential()

    def predict_name(self, code_block: str):
        """
        Placeholder for subclasses; this base implementation always raises.

        :arg code_block the input the subclass is expected to predict from.
        :raises NotImplementedError always.
        """
        raise NotImplementedError

    @staticmethod
    def from_file(path: str):
        """
        :arg path directory that directly contains a ``model.h5`` file. Note
             that ``save`` writes into ``<filepath>/<ClassName>/``, so the
             value passed here must include that class-name sub-directory.
        :return the object produced by ``load_model`` for ``<path>/model.h5``.
        """
        return load_model('{}/model.h5'.format(path))

    def save(self, filepath, overwrite=True, include_optimizer=True) -> None:
        """
        Write the model to ``<filepath>/<ClassName>/`` as four files:
        ``model_config.json`` (class name and hyperparameters),
        ``model.json`` (architecture), ``model_weights.h5`` (weights) and
        ``model.h5`` (complete model).

        :arg filepath base output location; the class name is appended as a
             sub-directory, which is created if it does not exist.
        :arg overwrite forwarded to the underlying ``save_weights``/``save``
             calls. The two JSON files are always rewritten.
        :arg include_optimizer forwarded to the underlying ``save`` call.
        """
        # Imported locally because this section of the file shows no import
        # block to extend.
        import os

        # NOTE(review): this call uses ``filepath`` itself as the weights
        # target, while everything below treats ``filepath`` as a directory.
        # It looks redundant with ``model_weights.h5``; kept for backward
        # compatibility - confirm whether it can be dropped.
        self.model.save_weights(filepath, overwrite=overwrite)

        model_type = type(self).__name__
        target_dir = '{path}/{name}'.format(path=filepath, name=model_type)
        # The per-class sub-directory must exist before any file is opened
        # inside it.
        os.makedirs(target_dir, exist_ok=True)

        model_config_to_save = {
            "model_type": model_type,
            "hyperparameters": self.hyperparameters,
        }

        # Save hyperparameters. The file has to be opened for writing; the
        # default mode of ``open`` is read-only.
        with open('{}/model_config.json'.format(target_dir), 'w',
                  encoding='utf-8') as fp:
            json.dump(model_config_to_save, fp)

        # Save the model architecture
        with open('{}/model.json'.format(target_dir), 'w',
                  encoding='utf-8') as model_json:
            model_json.write(self.model.to_json())

        # Save the weights
        self.model.save_weights('{}/model_weights.h5'.format(target_dir),
                                overwrite=overwrite)

        # Save the model completely
        self.model.save('{}/model.h5'.format(target_dir),
                        overwrite=overwrite,
                        include_optimizer=include_optimizer)
Exemple #2
0
# Training images: multiplied by 1/255 and randomly sheared, zoomed and
# horizontally flipped.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.2,
    rescale=1. / 255,
)

# Validation images: only the 1/255 rescale, no augmentation.
test_datagen = ImageDataGenerator(rescale=1. / 255)

# Options shared by both directory iterators.
# NOTE(review): Keras documents target_size as (height, width); the order
# here is (img_width, img_height) - confirm it matches the model's input.
_flow_options = dict(
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='binary',
)

train_generator = train_datagen.flow_from_directory(
    train_data_dir, **_flow_options)

validation_generator = test_datagen.flow_from_directory(
    validation_data_dir, **_flow_options)

# Number of whole batches per pass over each data set.
_train_steps = nb_train_samples // batch_size
_validation_steps = nb_validation_samples // batch_size

model.fit(
    train_generator,
    epochs=epochs,
    steps_per_epoch=_train_steps,
    validation_data=validation_generator,
    validation_steps=_validation_steps,
)

# Persist the complete model, then the weights on their own.
model.save('potato.h5')
model.save_weights('potato_weights.h5')