class Model():
    """Train, persist and score one model of a selectable type.

    Supported ``model_type`` values are listed in ``valid_types``:

    * ``'dnn'`` -- a Keras ``Sequential`` network ending in a single sigmoid
      unit, trained with early stopping on ``val_loss``; the best checkpoint
      is written to ``save_file_loc`` and reloaded as the final model.
    * ``'rf'``, ``'svm'``, ``'lr'`` / ``'logistic'`` (both build a
      ``LogisticRegression``), ``'elasticnet'`` -- scikit-learn estimators,
      pickled to ``save_file_loc`` after fitting.

    After :meth:`fit`, ``metric`` holds the validation accuracy returned by
    :meth:`evaluate`.

    NOTE(review): labels are presumably binary 0/1 -- inferred from the
    sigmoid output layer and the ``np.rint`` rounding in ``evaluate``; confirm
    against callers.
    """

    valid_types = ['dnn', 'logistic', 'rf', 'svm', 'lr', 'elasticnet']

    def __init__(self,
                 model_type,
                 num_of_hidden_layers=2,
                 width_of_hidden_layers=256,
                 batch_size=32,
                 nb_epoch=100,
                 patience=10,
                 rf_n_estimators=100,
                 rf_criterion='gini',
                 rf_max_depth=8,
                 rf_min_samples_split=2,
                 kernel='rbf',
                 engine_id=None):
        """Validate ``model_type`` and store the hyper-parameters.

        Args:
            model_type: One of ``valid_types``.
            num_of_hidden_layers: Number of extra hidden ``Dense`` layers
                added after the first one (dnn only).
            width_of_hidden_layers: Units per hidden ``Dense`` layer (dnn).
            batch_size: Keras training batch size (dnn).
            nb_epoch: Maximum number of training epochs (dnn).
            patience: Early-stopping patience in epochs (dnn).
            rf_n_estimators, rf_criterion, rf_max_depth, rf_min_samples_split:
                Passed to ``RandomForestClassifier`` (rf).
            kernel: Passed to ``SVC`` (svm).
            engine_id: Only used as a prefix of the save-file name.

        Raises:
            ValueError: If ``model_type`` is not in ``valid_types``.
        """
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if model_type not in self.valid_types:
            raise ValueError(
                f'model_type must be one of {self.valid_types}, got {model_type!r}')

        self.id = str(uuid.uuid4())
        # ``dir_loc`` is a module-level name defined outside this class.
        self.save_file_loc = f'{dir_loc}/text_analysis_results/models/{engine_id}_{self.id}_{model_type}_model'
        self.model_type = model_type
        self.num_of_hidden_layers = num_of_hidden_layers
        self.width_of_hidden_layers = width_of_hidden_layers
        self.batch_size = batch_size
        self.nb_epoch = nb_epoch
        self.patience = patience
        self.rf_n_estimators = rf_n_estimators
        self.rf_criterion = rf_criterion
        self.rf_max_depth = rf_max_depth
        self.rf_min_samples_split = rf_min_samples_split
        self.kernel = kernel
        self.model = None   # set by fit()
        self.metric = None  # validation accuracy, set by fit()

    def fit(self, x_train, x_val, y_train, y_val):
        """Train the configured model, save it to disk and score it.

        The trained model is stored in ``self.model`` and written to
        ``self.save_file_loc`` (Keras checkpoint for the dnn, pickle for the
        scikit-learn estimators). ``self.metric`` is then set to the
        validation accuracy from :meth:`evaluate`.

        Args:
            x_train, x_val: Feature matrices; the dnn reads the feature count
                from ``x_train.shape[1]``.
            y_train, y_val: Target labels for the two splits.
        """
        if self.model_type == 'dnn':
            self._fit_dnn(x_train, x_val, y_train, y_val)
        else:
            self.model = self._build_sklearn_model()
            self.model.fit(x_train, y_train)
            with open(self.save_file_loc, 'wb') as f:
                pickle.dump(self.model, f)

        self.metric = self.evaluate(x_val, y_val)

    def _fit_dnn(self, x_train, x_val, y_train, y_val):
        """Build and train the Keras network, then reload its best checkpoint."""
        self.model = Sequential()
        # NOTE(review): this first Dense layer is itself a hidden layer, so the
        # network has ``num_of_hidden_layers + 1`` hidden layers in total. Kept
        # unchanged so existing settings produce the same architecture.
        self.model.add(
            Dense(self.width_of_hidden_layers,
                  input_dim=x_train.shape[1],
                  activation='relu'))
        for _ in range(self.num_of_hidden_layers):
            self.model.add(
                Dense(self.width_of_hidden_layers, activation='relu'))
        self.model.add(Dense(1, activation='sigmoid'))
        self.model.compile(loss='binary_crossentropy',
                           optimizer='adam',
                           metrics=['acc'])

        early_stopping = callbacks.EarlyStopping(monitor='val_loss',
                                                 min_delta=0,
                                                 patience=self.patience,
                                                 verbose=0,
                                                 mode='auto')
        # Keep only the epoch with the lowest val_loss on disk.
        checkpoint = callbacks.ModelCheckpoint(self.save_file_loc,
                                               monitor='val_loss',
                                               verbose=0,
                                               save_best_only=True,
                                               save_weights_only=False,
                                               mode='auto',
                                               period=1)
        # NOTE(review): ``nb_epoch`` and ``period`` are legacy Keras argument
        # names (later releases use ``epochs`` / ``save_freq``); confirm the
        # pinned Keras version before renaming them.
        self.model.fit(x_train,
                       y_train,
                       validation_data=(x_val, y_val),
                       callbacks=[early_stopping, checkpoint],
                       batch_size=self.batch_size,
                       nb_epoch=self.nb_epoch)
        # Reload so the final model is the best checkpoint, not the last epoch.
        self.model = load_model(self.save_file_loc)

    def _build_sklearn_model(self):
        """Return an unfitted scikit-learn estimator for ``self.model_type``."""
        if self.model_type == 'rf':
            return RandomForestClassifier(
                n_estimators=self.rf_n_estimators,
                criterion=self.rf_criterion,
                max_depth=self.rf_max_depth,
                min_samples_split=self.rf_min_samples_split)
        if self.model_type == 'svm':
            return SVC(kernel=self.kernel)
        # 'logistic' is accepted by valid_types; treat it as an alias of 'lr'.
        if self.model_type in ('lr', 'logistic'):
            return LogisticRegression()
        if self.model_type == 'elasticnet':
            return ElasticNet()
        raise ValueError(f'unsupported model_type {self.model_type!r}')

    def predict(self, x):
        """Return a score per row of ``x``.

        * dnn: the network's sigmoid outputs exactly as returned by Keras
          ``predict`` (one column, from the ``Dense(1)`` output layer).
        * Estimators exposing ``predict_proba`` (e.g. rf, lr): the
          probability of the second class (column 1), i.e. the positive class
          for 0/1 labels.
        * Otherwise (e.g. svm without ``probability=True``, elasticnet): the
          raw ``predict`` output.
        """
        # Checked first: legacy Keras models also expose ``predict_proba``, but
        # their output has a single column, so ``[:, 1]`` would fail.
        if self.model_type == 'dnn':
            return self.model.predict(x)
        if hasattr(self.model, 'predict_proba'):
            return self.model.predict_proba(x)[:, 1]
        return self.model.predict(x)

    def evaluate(self, x_val, y_val):
        """Return the accuracy of rounded ``self.model.predict`` outputs on ``y_val``."""
        # Rounding turns sigmoid / regression outputs into integer labels;
        # ravel flattens the single-column dnn output to 1-D.
        preds = np.rint(self.model.predict(x_val)).astype(int).ravel()
        return accuracy_score(y_val, preds)