def preprocess_step_0(countries_metadata, country_index):
    """Convert one country's raw HS database into the internal DB format.

    Resolves the database name for *country_index*, derives the raw and
    internal file paths from the configured directories/extensions, and
    delegates the actual conversion (with optional Excel export, governed
    by SAVE_HS_DATABASES_TO_EXCEL).
    """
    db_name = get_hs_db_name(countries_metadata, country_index)
    source_path = RAW_HS_DATABASE_DIR + db_name + RAW_HS_DB_FILENAME_EXTENSION
    target_path = HS_DATABASE_DIR + db_name + HS_DB_FILENAME_EXTENSION
    convert_raw_data_to_internal_db_0(
        source_path, target_path, SAVE_HS_DATABASES_TO_EXCEL)
def test_0(country_index=161, network_size=1024):
    """Smoke-test the feature-loading + network pipeline for one country.

    Loads the first stored feature batch from the country's training-set
    directory, builds a fresh training network, and prints the shape of
    the network's output for that batch.

    Parameters
    ----------
    country_index : int, optional
        Index into the countries metadata. Defaults to 161, the value
        that was previously hard-coded here, so existing callers see
        identical behavior.
    network_size : int, optional
        Size passed to get_nn_for_training. Defaults to the previously
        hard-coded 1024.
    """
    countries_metadata = get_metadata()
    nm = get_hs_db_name(countries_metadata, country_index)
    ts_dir = TRAINING_SETS_MXNET_DIR + nm + '/'
    # Batch 0 is used as a representative sample of the training set.
    feat = load_features(ts_dir, 0)
    net = get_nn_for_training(network_size)
    print('TEST: ' + str(net(feat).shape))
def training_tool_0(countries_metadata, country_index):
    """Train the network for one country's HS database and save its weights.

    The network's output size is tied to the number of records in the
    country's database; training data is read from the country's
    training-set directory and the resulting parameters are written to
    the network-parameters directory.
    """
    db_name = get_hs_db_name(countries_metadata, country_index)
    training_dir = TRAINING_SETS_MXNET_DIR + db_name + '/'
    database_path = HS_DATABASE_DIR + db_name + HS_DB_FILENAME_EXTENSION
    record_count = get_number_of_records(database_path)
    network = train_0(get_nn_for_training(record_count), training_dir)
    params_path = (NETWORK_PARAMETERS_DIR + db_name
                   + NETWORK_PARAMETERS_FILENAME_EXTENSION)
    save_params(network, params_path)
def hs_search(countries_metadata, query, country_index, return_results=False):
    """Search one country's HS database for *query* using its trained net.

    Loads the country's database and trained network parameters, builds a
    prediction callable, and runs the underlying search. Results are
    written to the country's results file by hs_search_0; when
    *return_results* is True, they are also returned to the caller.
    """
    db_name = get_hs_db_name(countries_metadata, country_index)
    text_type = get_hs_db_characters_type(countries_metadata, country_index)
    database_path = HS_DATABASE_DIR + db_name + HS_DB_FILENAME_EXTENSION
    record_count = get_number_of_records(database_path)
    params_path = (NETWORK_PARAMETERS_DIR + db_name
                   + NETWORK_PARAMETERS_FILENAME_EXTENSION)
    network = get_nn_for_prediction(record_count, params_path)
    # Bind the loaded network into a single-argument prediction callable.
    predict = partial(predictor_0, network)
    results_path = RESULTS_DIR + db_name + RESULTS_FILENAME_EXTENSION
    records = load_dataframe_from_pickle_0(database_path)
    results = hs_search_0(records, predict, text_type, query,
                          results_path, return_results)
    if return_results:
        return results
def search_0(user_window):
    """Handle a search request from the GUI.

    Reads the query text from the window, runs the HS search for the
    currently selected country, loads the persisted results, and renders
    them into the window's results table together with a status message.

    Changes from the original: guard-clause early return instead of a
    deeply nested else-block, removed the commented-out Python-2 line,
    and dropped a redundant str() around the already-str '%'-format
    result. All user-visible strings are unchanged.
    """
    text = user_window.query_edit.text()
    if text == '':
        user_window.results_label.setText(
            '<font color=red>Please enter the description.</font>')
        return
    ci = user_window.countries_combo.currentIndex()
    hs_search(user_window.countries_metadata, text, ci)
    # hs_search persists its results; reload them from the results file.
    cid = get_hs_db_name(user_window.countries_metadata, ci)
    results = file_io_json.load(RESULTS_DIR + cid + RESULTS_FILENAME_EXTENSION)
    (nr, nc) = results.shape
    user_window.table.setRowCount(nr)
    user_window.table.setColumnCount(nc)
    cols = results.columns.values
    # Status message depends on how many rows came back; the table is
    # populated in every case (including the ambiguous multi-row one).
    if nr > 1:
        user_window.results_label.setText(
            '<font color=red>There is more than one result. '
            'Please provide more details for disambiguation.</font>')
    elif nr == 0:
        user_window.results_label.setText(
            '<font color=red>No results.</font>')
    else:
        user_window.results_label.setText('The result:')
    for j in range(nc):
        user_window.table.setHorizontalHeaderItem(j, QTableWidgetItem(cols[j]))
    for i in range(nr):
        for j in range(nc):
            # Relevance scores are displayed with two decimals;
            # every other cell is rendered via str().
            if cols[j] == HS_RESULTS_RELEVANCE_LABEL:
                r = "%.2f" % results.iat[i, j]
            else:
                r = str(results.iat[i, j])
            user_window.table.setItem(i, j, QTableWidgetItem(r))
    # The first column (description) gets most of the window width;
    # the remaining columns and all rows shrink to fit their contents.
    user_window.table.setColumnWidth(0, WINDOW_WIDTH - 250)
    for j in range(1, nc):
        user_window.table.resizeColumnToContents(j)
    for i in range(nr):
        user_window.table.resizeRowToContents(i)
def generate_training_set(countries_metadata, country_index):
    """Build the on-disk MXNet training set for one country's HS database.

    For every record in the database, the description is normalized,
    replicated into variants, converted to a feature representation, and
    saved as one batch (features + target record index). The batch and
    example counts are persisted alongside the data.
    """
    text_type = get_hs_db_characters_type(countries_metadata, country_index)
    db_name = get_hs_db_name(countries_metadata, country_index)
    training_dir = TRAINING_SETS_MXNET_DIR + db_name + '/'
    refresh_dir(training_dir)
    database_path = HS_DATABASE_DIR + db_name + HS_DB_FILENAME_EXTENSION
    record_count = get_number_of_records(database_path)
    batch_count = 0
    example_count = 0
    for record_idx in range(record_count):
        description = normalize_1(
            text_type, get_hs_description(database_path, record_idx))
        for variant in replicate_hs_desc_0(description):
            features = get_text_representation(variant)
            save_features(training_dir, batch_count, features)
            save_target(training_dir, batch_count, record_idx)
            # NOTE(review): batch_count and example_count currently
            # advance in lockstep (one example per batch); kept as-is.
            batch_count += 1
            example_count += 1
    save_int(training_dir + NUM_BATCHES_FILENAME, batch_count)
    save_int(training_dir + NUM_TRAINING_EXAMPLES_FILENAME, example_count)
    if VERBOSITY > 0:
        print('Training set has been generated from "' + db_name
              + '" database.')