예제 #1
0
        "avg_steps": total_steps_ran / steps,
    }

    sum_loss = 0.0
    steps = 0.0

    lol.eval()

    # Save epoch snapshot using some validation image
    model_path = os.path.join(args.output, args.name, 'last.pt')
    screenshot_path = os.path.join(args.output, args.name, "screenshots",
                                   str(epoch) + ".png")
    create_folders(screenshot_path)
    torch.save(lol.state_dict(), model_path)
    time.sleep(1)
    paint_model_run(model_path, validation_loader, destination=screenshot_path)

    for index, x in enumerate(test_dataloader):
        x = x[0]
        img = Variable(x['img'].type(dtype), requires_grad=False)[None, ...]
        ground_truth = x["steps"]

        # Iterates over the line until the end

        sol = ground_truth[0].cuda()
        predicted_steps, length, _ = lol(img,
                                         sol,
                                         ground_truth,
                                         max_steps=len(ground_truth),
                                         disturb_sol=False)
예제 #2
0
# Register the --output and --model options on the parser created earlier,
# then parse the command line.
parser.add_argument(
    "--output",
    default="scripts/original/snapshots/training",
)
parser.add_argument(
    "--model",
    default="scripts/new/snapshots/training2/lol-last.pt",
)
args = parser.parse_args()

# Root data directory: the DATA_FOLDER environment variable when it is set
# to a non-empty value, otherwise the relative folder "data".
data_folder = os.getenv("DATA_FOLDER") or "data"
target_folder = os.path.join(data_folder, "sfrs", args.dataset)
pages_folder = os.path.join(target_folder, "pages")
char_set_path = os.path.join(pages_folder, "character_set.json")

# Load the list stored in validation.json and keep only its first entry,
# so the loader below yields a single-item dataset.
test_set_list_path = os.path.join(pages_folder, "validation.json")
test_set_list = load_file_list_direct(test_set_list_path)
test_dataset = LolDataset(test_set_list[0:1])
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=lol_dataset.collate,
)

# Number of completed polling passes; also used as the output file name.
count = 0

# Loop forever: on each pass, call paint_model_run on the last.pt checkpoint
# of both snapshot folders, writing to screenshots/<folder>/<count>.png,
# then wait 30 seconds before the next pass.
while True:
    for t in ("training", "training2"):
        model_path = "scripts/new/snapshots/{}/last.pt".format(t)
        paint_model_run(
            model_path,
            test_dataloader,
            destination=os.path.join(
                "screenshots", t, "{}.png".format(count)
            ),
        )
    time.sleep(30)
    count += 1