Ejemplo n.º 1
0
def main(args):
    """Evaluate a saved model checkpoint and write the metrics to JSON.

    The training arguments stored in the checkpoint are loaded and then
    overridden by ``args``. The merged arguments configure the evaluation
    dataloader and the model. The metrics dictionary returned by
    ``evaluate_agent`` is written next to the checkpoint as
    ``<checkpoint without trailing .tar>_eval.json``.

    Args:
        args: Dictionary of evaluation options. Must contain "checkpoint"
            (path of the checkpoint to load) and "eval_data_path" (passed to
            the dataloader as "data_read_path").

    Raises:
        KeyError: If a required key is missing from ``args`` or from the
            loaded checkpoint ("args", "model_state").
    """
    # Read the checkpoint and train args.
    print("Loading checkpoint: {}".format(args["checkpoint"]))
    # NOTE(review): torch.load unpickles the file, so only checkpoints from
    # trusted sources should be loaded here.
    checkpoint = torch.load(args["checkpoint"], map_location=torch.device("cpu"))
    saved_args = checkpoint["args"]
    # Command-line arguments take precedence over the stored training ones.
    saved_args.update(args)
    support.pretty_print_dict(saved_args)

    # Dataloader for evaluation.
    dataloader_args = {
        "single_pass": True,
        "shuffle": False,
        "data_read_path": args["eval_data_path"],
        "get_retrieval_candidates": True
    }
    # saved_args is applied last, so any of the four keys above that are also
    # present in saved_args get overridden.
    # NOTE(review): confirm this precedence is intended.
    dataloader_args.update(saved_args)
    val_loader = loaders.DataloaderSIMMC(dataloader_args)
    saved_args.update(val_loader.get_data_related_arguments())

    # Model.
    wizard = models.Assistant(saved_args)
    # Load the checkpoint.
    wizard.load_state_dict(checkpoint["model_state"])

    # Evaluate the SIMMC model; only the metrics dictionary is saved.
    eval_dict, _ = evaluate_agent(wizard, val_loader, saved_args)

    # Derive the output path by swapping a trailing ".tar" for "_eval.json".
    # Only the suffix is touched: a blanket str.replace would also rewrite
    # ".tar" inside directory names, and would return the checkpoint path
    # unchanged when no ".tar" is present, so the dump below would overwrite
    # the checkpoint itself.
    checkpoint_path = saved_args["checkpoint"]
    checkpoint_suffix = ".tar"
    if checkpoint_path.endswith(checkpoint_suffix):
        path_stem = checkpoint_path[:-len(checkpoint_suffix)]
    else:
        path_stem = checkpoint_path
    save_path = path_stem + "_eval.json"

    print("Saving results: {}".format(save_path))
    with open(save_path, "w") as file_id:
        json.dump(eval_dict, file_id)
Ejemplo n.º 2
0
import options
import eval_simmc_agent as evaluation
from tools import support

# Parse the command-line options into the argument dictionary.
args = options.read_command_line()

# Training split: loader defaults first, then every command-line option on
# top (so options sharing a key with a default take precedence).
dataloader_args = {
    "single_pass": False,
    "shuffle": True,
    "data_read_path": args["train_data_path"],
    "get_retrieval_candidates": False,
    **args,
}
train_loader = loaders.DataloaderSIMMC(dataloader_args)
# Fold the arguments reported by the training loader back into args.
args.update(train_loader.get_data_related_arguments())

# Validation (DEV) split: only built when an evaluation path is supplied;
# otherwise val_loader stays None.
val_loader = None
if args["eval_data_path"]:
    dataloader_args = {
        "single_pass": True,
        "shuffle": False,
        "data_read_path": args["eval_data_path"],
        "get_retrieval_candidates": args["retrieval_evaluation"],
        **args,
    }
    val_loader = loaders.DataloaderSIMMC(dataloader_args)

# Model.