def inference():
    """Run the trained DKN model over the test behaviors and write one
    scored impression per line to data/test/answer.json.

    Label order is validated against data/test/truth.json: for every
    impression key, the next ground-truth label consumed from the
    dataloader stream must match the value recorded in truth.json.

    Side effects: reads tsv/npy/checkpoint files, writes answer.json,
    prints progress. Returns nothing.
    """
    dataset = DKNDataset('data/test/behaviors_cleaned.tsv',
                         'data/test/news_with_entity.tsv')
    print(f"Load inference dataset with size {len(dataset)}.")
    # shuffle=False + drop_last=False: row order must match truth.json.
    dataloader = DataLoader(dataset,
                            batch_size=Config.batch_size,
                            shuffle=False,
                            num_workers=Config.num_workers,
                            drop_last=False)

    # Trained embedding files, shape (num_entity_tokens, entity_embedding_dim).
    entity_embedding = np.load('data/train/entity_embedding.npy')
    # TODO: the entity embedding file is reused as the context embedding;
    # swap in the real context embedding once it is produced.
    context_embedding = np.load('data/train/entity_embedding.npy')

    dkn = DKN(Config, entity_embedding, context_embedding).to(device)
    checkpoint_path = latest_checkpoint('./checkpoint')
    print(f"Load saved parameters in {checkpoint_path}")
    # Fix: map_location lets a GPU-trained checkpoint load on a CPU-only host.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    dkn.load_state_dict(checkpoint['model_state_dict'])
    dkn.eval()

    y_pred = []
    y = []
    count = 0
    # Fix: disable autograd during inference to avoid building the graph.
    with torch.no_grad():
        with tqdm(total=len(dataloader), desc="Inferering") as pbar:
            for minibatch in dataloader:
                y_pred.extend(
                    dkn(minibatch["candidate_news"],
                        minibatch["clicked_news"]).tolist())
                y.extend(minibatch["clicked"].float().tolist())
                pbar.update(1)
                count += 1
                # NOTE(review): debug cap — only the first 500 minibatches
                # are scored; the StopIteration handler below warns about
                # the truncation. Remove to inference the full dataset.
                if count == 500:
                    break

    y_pred = iter(y_pred)
    y = iter(y)
    # Fix: context managers close both files even on error (the original
    # never closed truth_file and leaked answer.json on an exception).
    # truth.json is used for locating impressions and validating order.
    with open('./data/test/truth.json', 'r') as truth_file, \
            open('./data/test/answer.json', 'w') as submission_answer_file:
        try:
            for line in truth_file:
                user_truth = json.loads(line)
                user_inference = copy.deepcopy(user_truth)
                for k in user_truth['impression'].keys():
                    # Streams must stay aligned with the dataloader order.
                    assert next(y) == user_truth['impression'][k]
                    user_inference['impression'][k] = next(y_pred)
                submission_answer_file.write(json.dumps(user_inference) + '\n')
        except StopIteration:
            print(
                'Warning: Behaviors not fully inferenced. You can still run evaluate.py, but the evaluation result would be inaccurate.'
            )
def get_models_and_configs(args):
    """Load every requested ensemble model and its config from the latest checkpoint.

    Args:
        args: parsed CLI namespace providing `models` (list of model class
            names), `datasize`, and `device`.

    Returns:
        dict mapping model_name -> [model, config]; each model is restored
        from its latest checkpoint, moved to the requested device, and set
        to eval mode.

    Exits the process if any model has no checkpoint file.
    """
    ensemble_model_names = args.models
    # Fix: isinstance() instead of `type(x) == type([])` (same AssertionError
    # on failure, so callers are unaffected).
    assert isinstance(ensemble_model_names, list)
    models_configs = {model_name: [] for model_name in ensemble_model_names}
    # Fix: dropped the unused `idx` / enumerate() from the original loop.
    for model_name, model_list in models_configs.items():
        # Convention: config module exposes f"{model_name}Config", and
        # module model.<name> exposes a class named <name>.
        Config_cls = getattr(importlib.import_module('config'),
                             f"{model_name}Config")
        config = Config_cls()
        config.datasize = args.datasize
        config.model_name = model_name
        config.configure_datasize()
        config.device_str = args.device
        checkpoint_path = latest_checkpoint(
            os.path.join(config.checkpoint_dir, config.datasize,
                         config.model_name))
        if checkpoint_path is None:
            print('No checkpoint file found!')
            exit()
        print(f"Load saved parameters in {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        Model_cls = getattr(importlib.import_module(f"model.{model_name}"),
                            model_name)
        model = Model_cls(config).to(config.device_str)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        # Stored as a two-element list: [model, config].
        model_list.append(model)
        model_list.append(config)
    return models_configs
def evaluation():
    """Restore the configured model from its latest checkpoint and print
    AUC, MRR, nDCG@5 and nDCG@10 on the test split."""
    print('Using device:', device)
    print(f'Evaluating model {model_name}')
    # Pretrained word/entity/context embeddings are not loaded here;
    # they come back as part of the checkpoint state dict.
    model = Model(config).to(device)

    from train import latest_checkpoint  # Avoid circular imports
    ckpt_path = latest_checkpoint(path.join('./checkpoint', model_name))
    if ckpt_path is None:
        print('No checkpoint file found!')
        exit()
    print(f"Load saved parameters in {ckpt_path}")

    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    metrics = evaluate(model, './data/test')
    auc, mrr, ndcg5, ndcg10 = metrics
    print(
        f'AUC: {auc:.4f}\nMRR: {mrr:.4f}\nnDCG@5: {ndcg5:.4f}\nnDCG@10: {ndcg10:.4f}'
    )
    # NOTE(review): these two statements are the tail of a function whose
    # definition begins outside this chunk — txt_path is presumably the
    # prediction file written just above; confirm against the full file.
    print(f'saved {txt_path}')
    return


if __name__ == '__main__':
    print('Using device:', device)
    print(f'Evaluating model {model_name}')
    # Don't need to load pretrained word/entity/context embedding
    # since it will be loaded from checkpoint later
    model = Model(config).to(device)
    from train import latest_checkpoint  # Avoid circular imports
    # Checkpoint directory encodes the batch size and clicked-news history
    # length, e.g. ../checkpoint/<model>/batch_size64_num50.
    checkpoint_dir = os.path.join(
        '../checkpoint', model_name,
        'batch_size' + str(config.batch_size) + '_num' +
        str(config.num_clicked_news_a_user))
    checkpoint_path = latest_checkpoint(checkpoint_dir)
    if checkpoint_path is None:
        print('No checkpoint file found!')
        exit()
    print(f"Load saved parameters in {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Writes predictions for the real test set to prediction.txt
    # (True presumably enables txt generation — verify evaluate's signature).
    evaluate(model, '../data/real_test', True,
             '../data/real_test/prediction.txt')
    print('zippppp')
    # Zip the prediction file for submission; the write/close of this
    # archive continues beyond this chunk.
    f = zipfile.ZipFile('../data/real_test/prediction.zip', 'w',
                        zipfile.ZIP_DEFLATED)
                # NOTE(review): continuation of a write(...) call whose opening
                # lies outside this chunk — emits one line per impression:
                # "<impression_id> <ranks-without-spaces>".
                f"{minibatch['impression_id'][0]} {str(list(value2rank(impression).values())).replace(' ','')}\n"
            )
            pbar.update(1)

    if generate_txt:
        answer_file.close()
    # Aggregate per-impression metrics into scalar means.
    return np.mean(aucs), np.mean(mrrs), np.mean(ndcg5s), np.mean(ndcg10s)


if __name__ == '__main__':
    print('Using device:', device)
    print(f'Evaluating model {model_name}')
    # Don't need to load pretrained word/entity/context embedding
    # since it will be loaded from checkpoint later
    model = Model(config).to(device)
    from train import latest_checkpoint  # Avoid circular imports
    checkpoint_path = latest_checkpoint(
        os.path.join('../checkpoint', model_name))
    if checkpoint_path is None:
        print('No checkpoint file found!')
        exit()
    print(f"Load saved parameters in {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # Evaluate on the test split and also write prediction.txt
    # (True presumably toggles txt generation — verify evaluate's signature).
    auc, mrr, ndcg5, ndcg10 = evaluate(model, '../data/test', True,
                                       '../data/test/prediction.txt')
    print(
        f'AUC: {auc:.4f}\nMRR: {mrr:.4f}\nnDCG@5: {ndcg5:.4f}\nnDCG@10: {ndcg10:.4f}'
    )
# batch_size y = minibatch['label'].float() loss = criterion(y_pred, y.to(device)) loss_full.append(loss.item()) y_pred_full.extend((torch.sigmoid(y_pred) > 0.5).int().tolist()) y_full.extend(y.tolist()) pbar.update(1) return np.mean(loss_full), classification_report(y_full, y_pred_full, output_dict=True, zero_division=0) if __name__ == '__main__': test_dataset = MyDataset('data/test/news_parsed.tsv') # Don't need to load pretrained word embedding # since it will be loaded from checkpoint later model = Model(Config).to(device) from train import latest_checkpoint # Avoid circular imports checkpoint_path = latest_checkpoint('./checkpoint') if checkpoint_path is None: print('No checkpoint file found!') exit() checkpoint = torch.load(checkpoint_path) model.load_state_dict(checkpoint['model_state_dict']) model.eval() loss, report = evaluate(model, test_dataset) print(report['weighted avg'])