Example #1
0
def main(cfg: DictConfig) -> None:
    """Train and evaluate an FTWC agent with PyTorch Lightning.

    Seeds RNGs, resolves paths in the Hydra config, expands the configured
    training/test game lists into concrete ``*.ulx`` gamefiles, then runs
    ``trainer.fit()`` followed by ``trainer.test()`` and prints elapsed time.

    Args:
        cfg: Hydra/OmegaConf configuration (expects ``general.random_seed``,
            ``cwd_path``, ``training.games``, ``training.batch_size``,
            ``training.nb_epochs``, ``test.games``, ``test.num_test_episodes``).
    """
    seed_everything(cfg.general.random_seed)
    # Hydra changes the working directory; make cwd_path absolute before use.
    cfg.cwd_path = hydra.utils.to_absolute_path(cfg.cwd_path)
    print(OmegaConf.to_yaml(cfg, resolve=True))
    print("cwd_path = ", cfg.cwd_path)
    start_time = datetime.datetime.now()
    print(
        f"=================================================== train_ftwc.py - Start time: {start_time}\n{os.getcwd()}\n"
    )

    games = _expand_game_paths(cfg.training.games)
    test_games = _expand_game_paths(cfg.test.games)

    print("{} games found for training.".format(len(games)))
    data = GamefileDataModule(cfg, gamefiles=games, testfiles=test_games)
    data.setup()

    agent = FtwcAgentLit(cfg)
    # Evaluate only a subset of the test set to speed things up while
    # debugging.  max(1, ...) guards against the division yielding 0 when
    # num_test_episodes < batch_size (limit_test_batches=0 would disable
    # testing entirely).
    n_eval_subset = max(
        1, cfg.test.num_test_episodes // data.test_dataloader().batch_size)
    trainer = Trainer(
        gpus=1,
        deterministic=True,
        distributed_backend='dp',
        # max_epochs=0 => does not call training_step() at all
        max_epochs=cfg.training.nb_epochs,
        limit_val_batches=0,  # prevent validation_step() from getting called
        limit_test_batches=n_eval_subset,
    )
    # HACK! TEMPORARY: this should be part of train_dataloader
    agent.initialize_episode(games[:cfg.training.batch_size])

    # UGLY HACK so we can construct a transition and compute a loss
    agent.prepare_for_fake_replay()
    # END HACK
    trainer.fit(agent, data)
    # TODO: add callback to achieve:
    #     for epoch_no in range(1, cfg.training.nb_epochs + 1):
    #         ... from original train() method ...
    #         scores, steps = agent.run_episode(batch)
    #         print("Epoch: {:3d} | {:2.1f} pts | {:4.1f} steps".format(epoch_no, score, steps))

    trainer.test(agent, datamodule=data)
    finish_time = datetime.datetime.now()
    print(
        f"=================================================== train_ftwc.py - Finished : {finish_time} -- elapsed: {finish_time-start_time}"
    )


def _expand_game_paths(paths):
    """Expand a list of game paths into a flat list of gamefiles.

    Each entry is ``~``-expanded; a directory contributes every ``*.ulx``
    file it contains, while any other entry is kept as-is.
    """
    gamefiles = []
    for entry in paths:
        entry = os.path.expanduser(entry)
        if os.path.isdir(entry):
            gamefiles += glob.glob(os.path.join(entry, "*.ulx"))
        else:
            gamefiles.append(entry)
    return gamefiles