def main():
    """Train an ``IKNet`` model from command-line options.

    Parses the dataset CSV paths, train/val ratio, batch size, epoch count,
    learning rate and a ``--save-model`` flag. It then trains ``IKNet`` with
    Adam under a pytorch-pfn-extras ``ExtensionsManager`` that stops early on
    ``val/loss``. When ``--save-model`` is given, the trained weights are
    written to ``iknet.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kinematics-pose-csv",
        type=str,
        default="./dataset/train/kinematics_pose.csv",
    )
    parser.add_argument("--joint-states-csv",
                        type=str,
                        default="./dataset/train/joint_states.csv")
    parser.add_argument("--train-val-ratio", type=float, default=0.8)
    parser.add_argument("--batch-size", type=int, default=10000)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--save-model", action="store_true", default=False)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet()
    model.to(device)
    train_loader, val_loader = get_data_loaders(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Early stopping checks "val/loss" every 3 epochs. This trigger is used
    # as the manager's stop_trigger, replacing the (max_epochs, "epoch")
    # stop condition. EarlyStoppingTrigger's own upper bound defaults to
    # (100, "epoch"), so bound it by --epochs explicitly; otherwise that
    # option would be ignored.
    trigger = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger=(3, "epoch"),
        monitor="val/loss",
        max_trigger=(args.epochs, "epoch"),
    )
    my_extensions = [
        extensions.LogReport(),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        # Validation runs the project's `validate` over val_loader; it is
        # presumably what reports "val/loss" -- confirm in `validate`.
        extensions.Evaluator(
            val_loader,
            model,
            eval_func=lambda data, target: validate(args, model, device, data,
                                                    target),
            progress_bar=True,
        ),
        extensions.PlotReport(["train/loss", "val/loss"],
                              "epoch",
                              filename="loss.png"),
        extensions.PrintReport([
            "epoch",
            "iteration",
            "train/loss",
            "lr",
            "val/loss",
        ]),
    ]
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=my_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=trigger,
    )
    train(manager, args, model, device, train_loader)

    if args.save_model:
        torch.save(model.state_dict(), "iknet.pt")


def objective(trial):
    """Train one ``IKNet(trial)`` configuration and return the result of ``train``.

    Apart from building the model from ``trial``, this mirrors ``main``: Adam
    at ``args.lr``, a pytorch-pfn-extras ``ExtensionsManager`` and early
    stopping on ``val/loss``.

    NOTE(review): ``args`` is not defined in this function. It must resolve
    to a module-level ``args`` set before this is called — confirm at the
    call site.
    NOTE(review): the returned value is whatever ``train`` returns;
    presumably the metric being optimized — confirm in ``train``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet(trial)
    model.to(device)
    train_loader, val_loader = get_data_loaders(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Early stopping checks "val/loss" every 3 epochs. Used as the stop
    # trigger, it replaces the manager's (max_epochs, "epoch") bound, and its
    # own max_trigger defaults to (100, "epoch"). Pass args.epochs explicitly
    # so the configured epoch budget is honored.
    trigger = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger=(3, "epoch"),
        monitor="val/loss",
        max_trigger=(args.epochs, "epoch"),
    )
    my_extensions = [
        extensions.LogReport(),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        extensions.Evaluator(
            val_loader,
            model,
            eval_func=lambda data, target: validate(args, model, device, data,
                                                    target),
            progress_bar=True,
        ),
        # NOTE(review): every trial writes the same "loss.png".
        extensions.PlotReport(["train/loss", "val/loss"],
                              "epoch",
                              filename="loss.png"),
        extensions.PrintReport([
            "epoch",
            "iteration",
            "train/loss",
            "lr",
            "val/loss",
        ]),
    ]
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=my_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=trigger,
    )
    return train(manager, args, model, device, train_loader)