import numpy as np

# Project-local dependencies (assumed importable in this module):
# DataDUELoader, FM, json_reader, evaluate, read, RANDOM_SEED_NP.


def test(config_file, meta_data_file, id_map, dataToken, batch_data_dir,
         max_doc_length=30, model_name=None, restore_path=None, no_doc_index=False):
    """Evaluate an FM model on the data in batch_data_dir, either from one or
    more restored checkpoints or from a random initialization."""
    np.random.seed(RANDOM_SEED_NP)
    data = DataDUELoader(meta_data_file=meta_data_file,
                         batch_data_dir=batch_data_dir,
                         id_map=id_map,
                         dataToken=dataToken,
                         max_doc_length=max_doc_length,
                         no_doc_index=no_doc_index)

    model_spec = json_reader(config_file)
    # Feature shape: optional document-index block, plus user, vocabulary, and bias slots.
    model = FM(feature_shape=(0 if no_doc_index else data.D) + data.U + data.V + 1,
               feature_dim=(0 if no_doc_index else 1) + 1 + max_doc_length,
               label_dim=data.E,
               model_spec=model_spec,
               model_name=model_name)
    model.initialization()

    def performance(model_local, data_local):
        preds = model_local.predict(data_generator=data_local)
        labels = []
        for data_batched in data_local.generate(batch_size=model_spec["batch_size"],
                                                random_shuffle=False):
            labels.append(data_batched["label"])
        labels = np.concatenate(labels, axis=0)
        # one-hot to index #
        trues = np.argmax(labels, axis=-1)
        perf = evaluate(preds=preds, trues=trues)
        return perf

    if restore_path is not None:
        # Accept either a single checkpoint path or a list of them.
        if not isinstance(restore_path, list):
            restore_paths = [restore_path]
        else:
            restore_paths = restore_path
        for restore_path in restore_paths:
            model.restore(restore_path)
            perf = performance(model_local=model, data_local=data)
            print("ckpt_path: %s" % restore_path)
            print("performance: %s" % str(perf))
    else:
        perf = performance(model_local=model, data_local=data)
        print("random initialization")
        print("performance: %s" % str(perf))
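# Usage sketch (illustrative only, not from the original script): how test()
# might be invoked on saved checkpoints. Every path below is a hypothetical
# placeholder; passing a list for restore_path evaluates each checkpoint in turn.
#
#   test(config_file="../config/fm.json",
#        meta_data_file="../data/meta.json",
#        id_map="../data/id_map.json",
#        dataToken="../data/dataToken",
#        batch_data_dir="../data/test/",
#        restore_path=["../ckpt/fm-5", "../ckpt/fm-10"])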
def train(config_file, meta_data_file, id_map, dataToken, batch_data_dir_train,
          batch_data_dir_valid=None, max_doc_length=30, model_name=None,
          restore_path=None, no_doc_index=False):
    """Train an FM model on batch_data_dir_train, optionally validating on
    batch_data_dir_valid and warm-starting from restore_path."""
    np.random.seed(RANDOM_SEED_NP)
    data_train = DataDUELoader(meta_data_file=meta_data_file,
                               batch_data_dir=batch_data_dir_train,
                               id_map=id_map,
                               dataToken=dataToken,
                               max_doc_length=max_doc_length,
                               no_doc_index=no_doc_index)
    if batch_data_dir_valid is not None:
        data_valid = DataDUELoader(meta_data_file=meta_data_file,
                                   batch_data_dir=batch_data_dir_valid,
                                   id_map=id_map,
                                   dataToken=dataToken,
                                   max_doc_length=max_doc_length,
                                   no_doc_index=no_doc_index)
    else:
        data_valid = None

    model_spec = json_reader(config_file)
    model = FM(feature_shape=(0 if no_doc_index else data_train.D) + data_train.U + data_train.V + 1,
               feature_dim=(0 if no_doc_index else 1) + 1 + max_doc_length,
               label_dim=data_train.E,
               model_spec=model_spec,
               model_name=model_name)
    model.initialization()
    if restore_path is not None:
        model.restore(restore_path)

    # train #
    results = model.train(data_generator=data_train, data_generator_valid=data_valid)
    print("train_results: %s" % str(results))

    # Pick the epoch with the best validation loss from the training summaries.
    best_epoch = read(directory="../summary/" + model.model_name,
                      main_indicator="epoch_losses_valid_00")[0]
    print("best_epoch by validation loss: %d" % best_epoch)
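# Usage sketch (illustrative only): a hypothetical entry point wiring train()
# to concrete data. All paths and the model name below are placeholder
# assumptions, not values from this repo.

if __name__ == "__main__":
    train(config_file="../config/fm.json",          # hypothetical config path
          meta_data_file="../data/meta.json",       # hypothetical metadata path
          id_map="../data/id_map.json",
          dataToken="../data/dataToken",
          batch_data_dir_train="../data/train/",
          batch_data_dir_valid="../data/valid/",
          model_name="fm")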