Exemplo n.º 1
0
def main():
    """Train and validate the model with Catalyst's ``SupervisedRunner``.

    The model, data loaders, optimizer/scheduler pair and criterion all come
    from the project's ``get_*`` helpers. Training runs for a fixed number of
    epochs, writes its logs under ``./output/catalyst`` and tracks
    ``accuracy01`` as the main metric (maximized).
    """
    num_epochs = 5
    n_classes = 10
    logdir = './output/catalyst'

    # Uncomment both lines for reproducible runs (fixed seed, deterministic cuDNN):
    # catalyst.utils.set_global_seed(42)
    # catalyst.utils.prepare_cudnn(deterministic=True)

    net = get_model()
    train_dl, valid_dl = get_loaders()
    loaders = {"train": train_dl, "valid": valid_dl}

    optimizer, scheduler = get_optimizer(model=net)
    criterion = get_criterion()

    device = catalyst.utils.get_device()
    runner = SupervisedRunner(device=device)

    # accuracy_args=[1] asks for top-1 accuracy; presumably reported under the
    # "accuracy01" key that is used as main_metric below.
    accuracy_cb = AccuracyCallback(num_classes=n_classes, accuracy_args=[1])

    runner.train(
        model=net,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        logdir=logdir,
        callbacks=[accuracy_cb],
        num_epochs=num_epochs,
        main_metric="accuracy01",
        minimize_metric=False,  # accuracy: larger is better
        fp16=None,
        verbose=True,
    )
Exemplo n.º 2
0
def main():
    """Train for a fixed number of epochs using the project's ``run`` loop.

    Training runs on ``'cuda'`` when ``torch.cuda.is_available()`` reports a
    GPU and falls back to ``'cpu'`` otherwise.
    """
    num_epochs = 10
    train_dl, valid_dl = get_loaders()
    net = get_model()

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    optimizer, scheduler = get_optimizer(net)
    criterion = get_criterion()

    # Gather every training input in one place and hand it to `run` by keyword.
    run_kwargs = {
        "epochs": num_epochs,
        "model": net,
        "criterion": criterion,
        "optimizer": optimizer,
        "scheduler": scheduler,
        "train_loader": train_dl,
        "val_loader": valid_dl,
        "device": device,
    }
    run(**run_kwargs)
Exemplo n.º 3
0
def main():
    """Train and evaluate the model with a hand-written epoch loop.

    For each of ``epochs`` epochs, this function:

    - runs ``train`` over the training loader,
    - runs ``eval`` over the validation loader,
    - steps the LR scheduler once,
    - prints both metric results.

    The model is moved to CUDA when available, otherwise to the CPU.
    """
    epochs = 10

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model()
    train_loader, val_loader = get_loaders()

    # Move the model to its target device *before* building the optimizer, so
    # the optimizer is constructed over the parameters that will actually be
    # trained (PyTorch's torch.optim docs recommend this ordering). Multi-GPU
    # wrapping or FP16 setup of the model would also belong here.
    model = model.to(device)

    optimizer, lr_scheduler = get_optimizer(model=model)
    criterion = get_criterion()

    print('Train start !')
    for epoch in range(epochs):
        print(f'epoch {epoch} start !')
        metrics_train = train(model, train_loader, criterion, optimizer,
                              device)
        # NOTE(review): the builtin `eval` does not accept these four
        # arguments, so this presumably resolves to a module-level evaluation
        # function that shadows the builtin -- consider renaming it.
        metrics_eval = eval(model, val_loader, criterion, device)

        # One scheduler step per epoch.
        lr_scheduler.step()

        # Placeholder for logger handling; for now just print the raw metrics.
        print(f'epoch: {epoch} ', metrics_train, metrics_eval)
Exemplo n.º 4
0
 def val_dataloader(self):
     """Return the validation data loader (marked OPTIONAL in the template).

     ``get_loaders()`` returns a (train, validation) pair; only the
     validation loader is used here and the train loader is discarded.
     """
     _, valid_loader = get_loaders()
     return valid_loader
Exemplo n.º 5
0
 def train_dataloader(self):
     """Return the training data loader (marked REQUIRED in the template).

     ``get_loaders()`` returns a (train, validation) pair; only the training
     loader is used here and the validation loader is discarded.
     """
     train_loader, _ = get_loaders()
     return train_loader