def train_ensemble_model(ensemble_models, model_name, variation, dev_data, train_data=None, test_data=None,
                         binary_threshold=0.5, checkpoint_dir=None, overwrite=False, log_error=False,
                         save_log=True, **kwargs):
    """Train (optionally) and evaluate one classical-ML model used as an ensemble member.

    Builds a model chosen by ``model_name``, trains it on ``train_data`` when given
    (and no saved checkpoint exists, unless ``overwrite``), evaluates on ``dev_data``,
    optionally logs mis-classified dev examples, appends a performance log entry,
    and writes test-set predictions when ``test_data`` is provided.

    Args:
        ensemble_models: identifier of the ensemble composition; only used to build
            the experiment name (presumably a list/str of member names — TODO confirm).
        model_name: one of 'svm', 'lr', 'sgd', 'gnb', 'mnb', 'bnb', 'rf', 'gbdt',
            'xgboost', 'lda'.
        variation: dataset variation name, used in experiment/log/file names.
        dev_data: validation data passed to ``model.evaluate`` / ``model.error_analyze``.
        train_data: training data; if None, training is skipped and the saved model is loaded.
        test_data: if not None, predictions are written to a ``.labels`` file.
        binary_threshold: decision threshold stored on the config and in the log.
        checkpoint_dir: overrides ``config.checkpoint_dir`` when given; created if missing.
        overwrite: retrain even when a checkpoint file already exists.
        log_error: include per-example dev errors in the training log.
        save_log: append the log entry to the performance log file.
        **kwargs: forwarded to the model constructor.

    Returns:
        Tuple ``(valid_acc, valid_f1, valid_macro_f1, valid_p, valid_r)``.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    config = ModelConfig()
    config.binary_threshold = binary_threshold
    if checkpoint_dir is not None:
        config.checkpoint_dir = checkpoint_dir
        if not os.path.exists(config.checkpoint_dir):
            os.makedirs(config.checkpoint_dir)
    config.exp_name = '{}_{}_ensemble_with_{}'.format(variation, model_name, ensemble_models)

    train_log = {'exp_name': config.exp_name, 'binary_threshold': binary_threshold}
    print('Logging Info - Ensemble Experiment: ', config.exp_name)

    # Dispatch table replaces a 10-arm if/elif chain; keys are the accepted model names.
    model_classes = {
        'svm': SVMModel,
        'lr': LRModel,
        'sgd': SGDModel,
        'gnb': GaussianNBModel,
        'mnb': MultinomialNBModel,
        'bnb': BernoulliNBModel,
        'rf': RandomForestModel,
        'gbdt': GBDTModel,
        'xgboost': XGBoostModel,
        'lda': LDAModel,
    }
    if model_name not in model_classes:
        raise ValueError('Model Name Not Understood : {}'.format(model_name))
    model = model_classes[model_name](config, **kwargs)

    model_save_path = path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    # Train only when data is provided and there is no reusable checkpoint (or we overwrite).
    if train_data is not None and (not os.path.exists(model_save_path) or overwrite):
        model.train(train_data)

    # Always evaluate with the best saved weights, whether just trained or pre-existing.
    model.load_best_model()
    print('Logging Info - Evaluate over valid data:')
    valid_acc, valid_f1, valid_macro_f1, valid_p, valid_r = model.evaluate(dev_data)
    train_log['valid_acc'] = valid_acc
    train_log['valid_f1'] = valid_f1
    train_log['valid_macro_f1'] = valid_macro_f1
    train_log['valid_p'] = valid_p
    train_log['valid_r'] = valid_r
    train_log['time_stamp'] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    if log_error:
        # Record each mis-classified dev example (index, sentence, gold label, predicted prob).
        error_indexes, error_pred_probas = model.error_analyze(dev_data)
        dev_text_input = load_processed_text_data(variation, 'dev')
        for error_index, error_pred_prob in zip(error_indexes, error_pred_probas):
            train_log['error_%d' % error_index] = '{},{},{},{}'.format(
                error_index, dev_text_input['sentence'][error_index],
                dev_text_input['label'][error_index], error_pred_prob)
    if save_log:
        write_log(format_filename(LOG_DIR, PERFORMANCE_LOG_TEMPLATE, variation=variation),
                  log=train_log, mode='a')

    if test_data is not None:
        test_predictions = model.predict(test_data)
        writer_predict(format_filename(PREDICT_DIR, config.exp_name + '.labels'), test_predictions)

    return valid_acc, valid_f1, valid_macro_f1, valid_p, valid_r
def train_dl_model(variation, input_level, word_embed_type, word_embed_trainable, batch_size, learning_rate,
                   optimizer_type, model_name, binary_threshold=0.5, checkpoint_dir=None, overwrite=False,
                   log_error=False, save_log=True, **kwargs):
    """Train and evaluate one deep-learning text-classification model.

    Configures the model from the given hyper-parameters, loads the processed
    train/dev/test splits, trains when no checkpoint exists (or ``overwrite``),
    evaluates on the dev split, optionally logs mis-classified dev examples,
    appends a performance log entry, and writes test predictions.

    Args:
        variation: dataset variation name; '_aug' variations use larger max lengths.
        input_level: 'word' or 'char' style input granularity (passed to data loading).
        word_embed_type: embedding matrix identifier used to load pre-trained embeddings.
        word_embed_trainable: whether embeddings are fine-tuned ('tune') or frozen ('fix').
        batch_size: training batch size, stored on the config and in the log.
        learning_rate: optimizer learning rate.
        optimizer_type: optimizer name resolved via ``get_optimizer``.
        model_name: one of 'bilstm', 'cnnrnn', 'dcnn', 'dpcnn', 'han', 'multicnn',
            'rcnn', 'rnncnn', 'cnn', 'vdcnn'.
        binary_threshold: decision threshold stored on the config and in the log.
        checkpoint_dir: overrides ``config.checkpoint_dir`` when given; created if missing.
        overwrite: retrain even when a checkpoint file already exists.
        log_error: include per-example dev errors in the training log.
        save_log: append the log entry to the performance log file.
        **kwargs: forwarded to the model constructor.

    Returns:
        Tuple ``(valid_acc, valid_f1, valid_macro_f1, valid_p, valid_r)``.

    Raises:
        ValueError: if ``model_name`` is not recognized.
    """
    config = ModelConfig()
    config.variation = variation
    config.input_level = input_level
    # Augmented variations use longer sequence limits for both word and char inputs.
    if '_aug' in variation:
        config.max_len = {'word': config.aug_word_max_len, 'char': config.aug_char_max_len}
    config.word_embed_type = word_embed_type
    config.word_embed_trainable = word_embed_trainable
    config.word_embeddings = np.load(format_filename(PROCESSED_DATA_DIR, EMBEDDING_MATRIX_TEMPLATE,
                                                     variation=variation, type=word_embed_type))
    config.batch_size = batch_size
    config.learning_rate = learning_rate
    config.optimizer = get_optimizer(optimizer_type, learning_rate)
    config.binary_threshold = binary_threshold
    if checkpoint_dir is not None:
        config.checkpoint_dir = checkpoint_dir
        if not os.path.exists(config.checkpoint_dir):
            os.makedirs(config.checkpoint_dir)
    config.exp_name = '{}_{}_{}_{}_{}'.format(variation, model_name, input_level, word_embed_type,
                                              'tune' if word_embed_trainable else 'fix')

    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'learning_rate': learning_rate, 'binary_threshold': binary_threshold}
    print('Logging Info - Experiment: ', config.exp_name)

    # Dispatch table replaces a 10-arm if/elif chain; keys are the accepted model names.
    model_classes = {
        'bilstm': BiLSTM,
        'cnnrnn': CNNRNN,
        'dcnn': DCNN,
        'dpcnn': DPCNN,
        'han': HAN,
        'multicnn': MultiTextCNN,
        'rcnn': RCNN,
        'rnncnn': RNNCNN,
        'cnn': TextCNN,
        'vdcnn': VDCNN,
    }
    if model_name not in model_classes:
        raise ValueError('Model Name Not Understood : {}'.format(model_name))
    model = model_classes[model_name](config, **kwargs)

    train_input = load_processed_data(variation, input_level, 'train')
    dev_input = load_processed_data(variation, input_level, 'dev')
    test_input = load_processed_data(variation, input_level, 'test')

    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.train(train_input, dev_input)
        elapsed_time = time.time() - start_time
        train_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
        # BUG FIX: the original passed '%s' and the time string as two print args,
        # printing a literal '%s'; print() does not apply %-formatting.
        print('Logging Info - Training time: %s' % train_time)
        train_log['train_time'] = train_time

    # load the best model
    model.load_best_model()
    print('Logging Info - Evaluate over valid data:')
    valid_acc, valid_f1, valid_macro_f1, valid_p, valid_r = model.evaluate(dev_input)
    train_log['valid_acc'] = valid_acc
    train_log['valid_f1'] = valid_f1
    train_log['valid_macro_f1'] = valid_macro_f1
    train_log['valid_p'] = valid_p
    train_log['valid_r'] = valid_r
    train_log['time_stamp'] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    if log_error:
        # Record each mis-classified dev example (index, sentence, gold label, predicted prob).
        error_indexes, error_pred_probas = model.error_analyze(dev_input)
        dev_text_input = load_processed_text_data(variation, 'dev')
        for error_index, error_pred_prob in zip(error_indexes, error_pred_probas):
            train_log['error_%d' % error_index] = '{},{},{},{}'.format(
                error_index, dev_text_input['sentence'][error_index],
                dev_text_input['label'][error_index], error_pred_prob)
    if save_log:
        write_log(format_filename(LOG_DIR, PERFORMANCE_LOG_TEMPLATE, variation=variation),
                  log=train_log, mode='a')

    if test_input is not None:
        test_predictions = model.predict(test_input)
        writer_predict(format_filename(PREDICT_DIR, config.exp_name + '.labels'), test_predictions)

    return valid_acc, valid_f1, valid_macro_f1, valid_p, valid_r