示例#1
0
def main(args):
    """Evaluate model and save the results.
    """
    # Read the checkpoint and train args.
    print("Loading checkpoint: {}".format(args["checkpoint"]))
    checkpoint = torch.load(args["checkpoint"],
                            map_location=torch.device("cpu"))
    saved_args = checkpoint["args"]
    saved_args.update(args)
    support.pretty_print_dict(saved_args)

    # Dataloader for evaluation.
    dataloader_args = {
        "single_pass": True,
        "shuffle": False,
        "data_read_path": args["eval_data_path"],
        "get_retrieval_candidates": True
    }
    dataloader_args.update(saved_args)
    val_loader = loaders.DataloaderSIMMC(dataloader_args)
    saved_args.update(val_loader.get_data_related_arguments())

    # Model.
    wizard = models.Assistant(saved_args)
    # Load the checkpoint.
    wizard.load_state_dict(checkpoint["model_state"])

    # Evaluate the SIMMC model.
    eval_dict, eval_outputs = evaluate_agent(wizard, val_loader, saved_args)
    save_path = saved_args["checkpoint"].replace(".tar", "_eval.json")
    print("Saving results: {}".format(save_path))
    with open(save_path, "w") as file_id:
        json.dump(eval_dict, file_id)
# NOTE(review): this fragment relies on `args` and `train_loader` created
# earlier in the script (not visible here); the names bound below
# (val_loader, wizard, optimizer, smoother, num_iters_per_epoch, eval_dict,
# best_epoch) are presumably consumed by the training loop that follows.
# Merge the data-dependent settings reported by the train loader into args.
args.update(train_loader.get_data_related_arguments())
# Initiate the loader for val (DEV) data split.
# Evaluation data is optional: with a falsy "eval_data_path", val_loader is None.
if args["eval_data_path"]:
    # Single unshuffled pass; retrieval candidates only when
    # "retrieval_evaluation" is set.
    dataloader_args = {
        "single_pass": True,
        "shuffle": False,
        "data_read_path": args["eval_data_path"],
        "get_retrieval_candidates": args["retrieval_evaluation"]
    }
    # NOTE(review): any of the keys above that also appear in args are
    # overridden by args here; confirm this precedence is intended.
    dataloader_args.update(args)
    val_loader = loaders.DataloaderSIMMC(dataloader_args)
else:
    val_loader = None

# Model, switched to training mode.
wizard = models.Assistant(args)
wizard.train()
# For the "tf_idf" encoder, set the encoder's IDF tensor data from
# train_loader.IDF (via _ship_helper — presumably device placement; confirm).
if args["encoder"] == "tf_idf":
    wizard.encoder.IDF.data = train_loader._ship_helper(train_loader.IDF)

# Optimizer (Adam; second positional argument is the learning rate).
optimizer = torch.optim.Adam(wizard.parameters(), args["learning_rate"])

# Training iterations.
smoother = support.ExponentialSmoothing()
# True division, so this may be fractional.
num_iters_per_epoch = train_loader.num_instances / args["batch_size"]
print("Number of iterations per epoch: {:.2f}".format(num_iters_per_epoch))
# Presumably filled by the training loop below, with best_epoch == -1
# meaning "no best epoch yet" — confirm against that loop.
eval_dict = {}
best_epoch = -1

# first_batch = None