print("Using model ({}) for illustration...".format(model_save_path))
else:
    # Otherwise use the checkpoint path given as the first command-line argument.
    model_save_path = sys.argv[1]

# ---- Evaluate the trained checkpoint on every MultiNLI genre ----

# config logger
logger = logging.getLogger('__main__')
util.init_logger()
logger.info("Evaluation of trained model on MultiNLI. Model ({})".format(
    model_save_path))

# init supervisor
spv = Supervisor(config)
# load trained embedding
spv.load_trained_emb()
# load trained model
spv.load_checkpoint(f_path=model_save_path)

# One record per genre; turned into a DataFrame for the summary table below.
eval_results = []
# Evaluate on each genre of MultiNLI
for genre in MultiGenre:
    logger.info("##### GENRE: {} #####".format(genre.upper()))
    # Reload the MNLI data for this genre and rebuild the loaders.
    # NOTE(review): presumably load_data keeps only this genre's samples, so that
    # spv.loaders[LoaderType.VAL] is genre-specific; confirm in Supervisor.load_data.
    spv.load_data(TaskName.MNLI, genre)
    spv.get_dataloader()
    # eval_model returns (accuracy, loss); only the accuracy is recorded here.
    acc, _ = spv.eval_model(spv.loaders[LoaderType.VAL])
    eval_results.append({'Genre': genre, "valAcc": acc})

# Print the per-genre validation-accuracy table.
print("\n", pd.DataFrame.from_records(eval_results), "\n")
logger.info(
    "Model from {}\nEvaluation on MultiNLI is done!\n".format(model_save_path))
if len(sys.argv) < 2: print("=== ! ===\nNo input for trained model path!") model_save_path = 'results_1540947201/checkpoints/encRNNhid50lea0.01.tar' print("Using model ({}) for illustration...".format(model_save_path)) else: model_save_path = sys.argv[1] # init supervisor spv = Supervisor(config) spv.load_trained_emb() spv.load_data() spv.get_dataloader() # ============= Load checkpoint ============ # load trained model spv.load_checkpoint(f_path='./results_1540933016/checkpoints/demo_rnn_hidd200_drop0.2_if_eTrue.tar') acc, loss = spv.eval_model(spv.loaders[LoaderType.VAL]) print("Load model evaluated on val_loader: (valAcc){} (valLoss){}".format(acc, loss)) # ============= Sample Analysis ============ # Find at least 3 samples from correct and # incorrect classification respectively spv.model.eval() corr_count = 0 incorr_count = 0 for prem, hypo, p_len, h_len, labels in spv.loaders[LoaderType.VAL]: outputs = F.softmax(spv.model(prem, hypo, p_len, h_len), dim=1) predicted = outputs.max(1, keepdim=True)[1] eq = predicted.eq(labels.view_as(predicted)).numpy() for i in range(len(eq)): if eq[i] == 1: # correct