コード例 #1
0
def run_catalyst(irunner: dl.IRunner,
                 idx: int,
                 device: str = "cuda",
                 num_epochs: int = 10):
    """Train the model provided by ``irunner`` and report the final train metrics.

    The global seed is set from ``idx`` before any component is requested
    from ``irunner``. Training is driven by a fresh ``dl.SupervisedRunner``
    on the ``"train"`` loader only, with a top-1 accuracy callback attached.

    Args:
        irunner: source of the loaders, model, criterion and optimizer.
        idx: value passed to ``utils.set_global_seed``.
        device: device the model is moved to; exactly ``"cuda"`` selects
            ``dl.GPUEngine``, any other value selects ``dl.CPUEngine``.
        num_epochs: number of epochs passed to ``runner.train``.

    Returns:
        A 3-tuple ``(accuracy01, loss, used_memory)`` where the first two
        entries come from ``runner.epoch_metrics["train"]`` and the last is
        whatever ``_get_used_memory()`` returns.
    """
    # Seed first so that everything built below sees the same global seed.
    utils.set_global_seed(idx)

    train_loader = irunner.get_loaders()["train"]
    net = irunner.get_model().to(device)
    loss_fn = irunner.get_criterion()
    optim = irunner.get_optimizer(net)

    runner = dl.SupervisedRunner()

    # Only the literal string "cuda" picks the GPU engine.
    if device == "cuda":
        engine = dl.GPUEngine()
    else:
        engine = dl.CPUEngine()

    # The accuracy callback reads the runner's own output/target keys.
    accuracy_callback = dl.AccuracyCallback(
        input_key=runner._output_key,
        target_key=runner._target_key,
        topk=(1, ),
    )

    runner.train(
        engine=engine,
        model=net,
        criterion=loss_fn,
        optimizer=optim,
        loaders={"train": train_loader},
        num_epochs=num_epochs,
        verbose=False,
        callbacks=[accuracy_callback],
    )

    train_metrics = runner.epoch_metrics["train"]
    accuracy = train_metrics["accuracy01"]
    loss = train_metrics["loss"]
    return accuracy, loss, _get_used_memory()
コード例 #2
0
def test_run_on_cpu():
    """Run ``train_experiment`` with a ``dl.CPUEngine`` instance."""
    cpu_engine = dl.CPUEngine()
    train_experiment(cpu_engine)
コード例 #3
0
ファイル: ddpg.py プロジェクト: catalyst-team/catalyst
    criterion = torch.nn.MSELoss()
    optimizer = {
        "actor": torch.optim.Adam(actor.parameters(), lr_actor),
        "critic": torch.optim.Adam(critic.parameters(), lr=lr_critic),
    }

    loaders = {
        "train_game": DataLoader(
            ReplayDataset(replay_buffer, epoch_size=epoch_size), batch_size=batch_size
        ),
    }

    runner = CustomRunner(gamma=gamma, tau=tau, tau_period=tau_period)

    runner.train(
        engine=dl.CPUEngine(),  # for simplicity reasons, let's run everything on cpu
        model=models,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs_ddpg",
        num_epochs=10,
        verbose=True,
        valid_loader="_epoch_",
        valid_metric="v_reward",
        minimize_valid_metric=False,
        load_best_on_end=True,
        callbacks=[
            GameCallback(
                env=env,
                replay_buffer=replay_buffer,