def _test(args):
    """Evaluate a reloaded model on ``args.eval_data`` and dump its predictions.

    Builds the model, loads the checkpoint ``args.reload_model_name`` into it,
    switches it to eval mode, and runs it over the evaluation dataset. Only the
    raw (unmasked) logits returned by the model are scored.

    Side effects:
        * prints and logs the MRR and the string from ``get_eval_string``;
        * writes ``./<reload_model_name>.p``: a pickle of
          ``{'gold_id_array': [...], 'pred_dist': [...]}`` (per-example label
          arrays and raw-logit arrays);
        * writes ``./<reload_model_name>.json``:
          ``{annot_id: {"gold": ..., "pred": ...}}``.

    Args:
        args: parsed options namespace; ``args.load`` must be truthy.
    """
    assert args.load
    test_fname = args.eval_data
    data_gens = get_datasets([(test_fname, 'test', args.goal)], args)
    model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
    model.cuda()
    model.eval()
    load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model)

    for name, dataset in [(test_fname, data_gens[0])]:
        print('Processing... ' + name)
        total_gold_pred = []
        total_annot_ids = []
        total_probs = []
        total_ys = []
        for batch in dataset:
            eval_batch, annot_ids = to_torch(batch)
            # The model returns (loss, masked_logits, raw_logits, mask); only
            # the raw logits are evaluated here.
            _, _, raw_logits, _ = model(eval_batch, args.goal)
            output_index = get_output_index(raw_logits)
            output_prob = raw_logits.data.cpu().clone().numpy()
            y = eval_batch['y'].data.cpu().clone().numpy()
            gold_pred = get_gold_pred_str(output_index, y, args.goal)
            total_probs.extend(output_prob)
            total_ys.extend(y)
            total_gold_pred.extend(gold_pred)
            total_annot_ids.extend(annot_ids)

        mrr_val = mrr(total_probs, total_ys)
        print('mrr_value: ', mrr_val)

        # Use a context manager so the pickle file is flushed and closed.
        with open('./{0:s}.p'.format(args.reload_model_name), "wb") as f_pkl:
            pickle.dump({
                'gold_id_array': total_ys,
                'pred_dist': total_probs
            }, f_pkl)

        with open('./{0:s}.json'.format(args.reload_model_name), 'w') as f_out:
            output_dict = {}
            for a_id, (gold, pred) in zip(total_annot_ids, total_gold_pred):
                output_dict[a_id] = {"gold": gold, "pred": pred}
            json.dump(output_dict, f_out, indent=2)

        eval_str = get_eval_string(total_gold_pred)
        print(eval_str)
        logging.info('processing: ' + name)
        logging.info(eval_str)
def get_mrr(pred_fname):
    """Return the MRR stored in a prediction pickle.

    The pickle is expected to hold a dict with ``'pred_dist'`` and
    ``'gold_id_array'`` keys, which is the format ``_test`` writes. The two
    values are passed to ``mrr``.

    Args:
        pred_fname: path to the pickle file.

    Returns:
        The value returned by ``mrr(pred_dist, gold_id_array)``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # prediction files produced by this project.
    with open(pred_fname, "rb") as f_in:
        dicts = pickle.load(f_in)
    return mrr(dicts['pred_dist'], dicts['gold_id_array'])