def main(args):
    """Evaluate model and save the results.

    Loads the checkpoint at ``args["checkpoint"]`` onto the CPU and merges its
    saved training arguments with ``args`` (``args`` wins on conflicts). It
    then builds a dataloader over ``args["eval_data_path"]``, restores the
    model weights and runs ``evaluate_agent``. The returned evaluation
    dictionary is written as JSON next to the checkpoint.

    The output path is the checkpoint path with a trailing ``.tar`` (if any)
    removed and ``_eval.json`` appended, e.g.
    ``ckpt/epoch_10.tar`` -> ``ckpt/epoch_10_eval.json``. The output path
    therefore never equals the checkpoint path.

    Args:
        args: Dictionary of evaluation arguments. Read here: ``"checkpoint"``
            and ``"eval_data_path"``.
    """
    # Read the checkpoint and train args.
    print("Loading checkpoint: {}".format(args["checkpoint"]))
    checkpoint = torch.load(args["checkpoint"], map_location=torch.device("cpu"))
    saved_args = checkpoint["args"]
    # Arguments given now override the ones saved at training time.
    saved_args.update(args)
    support.pretty_print_dict(saved_args)

    # Dataloader for evaluation.
    # NOTE(review): `saved_args` is applied on top of these settings, so any of
    # these keys present in it take precedence; presumably the saved args do
    # not contain them — confirm.
    dataloader_args = {
        "single_pass": True,
        "shuffle": False,
        "data_read_path": args["eval_data_path"],
        "get_retrieval_candidates": True
    }
    dataloader_args.update(saved_args)
    val_loader = loaders.DataloaderSIMMC(dataloader_args)
    # Merge data-dependent settings reported by the loader before the model
    # is constructed from `saved_args`.
    saved_args.update(val_loader.get_data_related_arguments())

    # Model.
    wizard = models.Assistant(saved_args)
    # Load the checkpoint.
    wizard.load_state_dict(checkpoint["model_state"])

    # Evaluate the SIMMC model. Only the metrics dictionary is saved; the
    # second return value is not used here.
    eval_dict, _ = evaluate_agent(wizard, val_loader, saved_args)

    # Derive the output path. Only a *trailing* ".tar" is stripped:
    # str.replace would rewrite every ".tar" in the path, and for a checkpoint
    # without ".tar" it would return the checkpoint path unchanged, so the
    # json.dump below would overwrite the checkpoint itself.
    base_path = saved_args["checkpoint"]
    if base_path.endswith(".tar"):
        base_path = base_path[:-len(".tar")]
    save_path = base_path + "_eval.json"
    print("Saving results: {}".format(save_path))
    with open(save_path, "w") as file_id:
        json.dump(eval_dict, file_id)
# Script setup: parse command-line arguments, then build the training
# dataloader and, if an evaluation data path is given, a validation dataloader.
import options
# NOTE(review): `evaluation` is not used in the lines visible here; presumably
# it is used further down in the script — confirm.
import eval_simmc_agent as evaluation
from tools import support
# NOTE(review): `loaders` is used below but not imported in these lines;
# presumably it is imported elsewhere in the file — confirm.

# Arguments.
args = options.read_command_line()

# Dataloader.
# Training loader settings: not single-pass, shuffled, no retrieval
# candidates. `update(args)` runs afterwards, so any of these keys that also
# appear in `args` override the values below.
dataloader_args = {
    "single_pass": False,
    "shuffle": True,
    "data_read_path": args["train_data_path"],
    "get_retrieval_candidates": False
}
dataloader_args.update(args)
train_loader = loaders.DataloaderSIMMC(dataloader_args)
# Fold data-dependent settings reported by the training loader back into
# `args`. This runs before the validation loader is built, so that loader also
# receives these settings through its own `dataloader_args.update(args)`.
args.update(train_loader.get_data_related_arguments())

# Initiate the loader for val (DEV) data split.
# Built only when `args["eval_data_path"]` is truthy. Unless `args` itself
# supplies the key, `get_retrieval_candidates` follows
# `args["retrieval_evaluation"]`.
if args["eval_data_path"]:
    dataloader_args = {
        "single_pass": True,
        "shuffle": False,
        "data_read_path": args["eval_data_path"],
        "get_retrieval_candidates": args["retrieval_evaluation"]
    }
    dataloader_args.update(args)
    val_loader = loaders.DataloaderSIMMC(dataloader_args)
else:
    # No validation split; note `dataloader_args` still holds the training
    # loader settings in this case.
    val_loader = None

# Model.