示例#1
0
def run_catalyst(irunner: dl.IRunner,
                 idx: int,
                 device: str = "cuda",
                 num_epochs: int = 10):
    """Train the model supplied by ``irunner`` and report the final metrics.

    The components (train loader, model, criterion, optimizer) are pulled
    from ``irunner`` and handed to a fresh ``dl.SupervisedRunner``, which is
    trained on the "train" loader only, with a top-1 accuracy callback.

    Args:
        irunner: provider of loaders, model, criterion and optimizer.
        idx: value passed to ``utils.set_global_seed`` before anything else.
        device: target passed to ``model.to``; exactly ``"cuda"`` selects
            ``dl.GPUEngine``, any other value selects ``dl.CPUEngine``.
        num_epochs: number of epochs passed to ``runner.train``.

    Returns:
        A 3-tuple ``(accuracy01, loss, used_memory)`` where the first two
        come from ``runner.epoch_metrics["train"]`` and the last one is the
        result of ``_get_used_memory()``.
    """
    utils.set_global_seed(idx)

    # Keep this retrieval order: each getter may consume random state.
    train_loader = irunner.get_loaders()["train"]
    net = irunner.get_model().to(device)
    loss_fn = irunner.get_criterion()
    optim = irunner.get_optimizer(net)

    runner = dl.SupervisedRunner()

    if device == "cuda":
        engine = dl.GPUEngine()
    else:
        engine = dl.CPUEngine()

    # The callback reads its keys from the runner so both stay in sync.
    accuracy_callback = dl.AccuracyCallback(
        input_key=runner._output_key,
        target_key=runner._target_key,
        topk=(1, ),
    )

    runner.train(
        engine=engine,
        model=net,
        criterion=loss_fn,
        optimizer=optim,
        loaders={"train": train_loader},
        num_epochs=num_epochs,
        verbose=False,
        callbacks=[accuracy_callback],
    )

    train_metrics = runner.epoch_metrics["train"]
    accuracy = train_metrics["accuracy01"]
    loss = train_metrics["loss"]
    return accuracy, loss, _get_used_memory()
示例#2
0
def test_run_on_amp():
    """Run ``train_experiment`` with a ``dl.GPUEngine`` built with ``fp16=True``."""
    engine = dl.GPUEngine(fp16=True)
    train_experiment(engine)
示例#3
0
def test_run_on_torch_cuda0():
    """Run ``train_experiment`` with a default-constructed ``dl.GPUEngine``."""
    engine = dl.GPUEngine()
    train_experiment(engine)
示例#4
0
 def get_engine(self):
     """Return a new ``dl.GPUEngine`` constructed with default arguments."""
     engine = dl.GPUEngine()
     return engine
示例#5
0
        n_step=1,
        gamma=gamma,
        history_len=1,
    )

    network, target_network = get_network(env), get_network(env)
    set_requires_grad(target_network, requires_grad=False)
    models = nn.ModuleDict({"origin": network, "target": target_network})
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    loaders = {"train_game": DataLoader(replay_buffer, batch_size=batch_size)}

    runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)
    runner.train(
        # for simplicity reasons, let's run everything on single gpu
        engine=dl.GPUEngine(),
        model=models,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs_dqn",
        num_epochs=50,
        verbose=True,
        valid_loader="_epoch_",
        valid_metric="reward",
        minimize_valid_metric=False,
        load_best_on_end=True,
        callbacks=[
            GameCallback(
                sampler_fn=Sampler,
                env=env,