# Imports for the inference client below. IKNet itself is the project's own
# model definition (module path not shown here).
import argparse
import sys

import rclpy
import torch
from open_manipulator_msgs.msg import JointPosition
from open_manipulator_msgs.srv import SetJointPosition


def main():
    rclpy.init(args=sys.argv)
    node = rclpy.create_node("iknet_inference")
    set_joint_position = node.create_client(SetJointPosition,
                                            "/goal_joint_space_path")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="./iknet.pth",
    )
    parser.add_argument("--trt", action="store_true", default=False)
    parser.add_argument("--x", type=float, default=0.1)
    parser.add_argument("--y", type=float, default=0.0)
    parser.add_argument("--z", type=float, default=0.1)
    parser.add_argument("--qx", type=float, default=0.0)
    parser.add_argument("--qy", type=float, default=0.0)
    parser.add_argument("--qz", type=float, default=0.0)
    parser.add_argument("--qw", type=float, default=1.0)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use either the plain PyTorch model or the TensorRT-optimized module.
    if not args.trt:
        model = IKNet()
    else:
        from torch2trt import TRTModule
        model = TRTModule()
    model.to(device)
    model.load_state_dict(torch.load(args.model, map_location=device))
    model.eval()
    pose = [args.x, args.y, args.z, args.qx, args.qy, args.qz, args.qw]
    # The TensorRT module expects a batch dimension; the plain model takes a 1-D pose.
    if not args.trt:
        input_ = torch.FloatTensor(pose)
    else:
        input_ = torch.FloatTensor([pose])
    input_ = input_.to(device)
    print(f"input: {input_}")
    with torch.no_grad():
        output = model(input_)
    print(f"output: {output}")

    # Fill the goal joint-space request with the predicted joint angles.
    joint_position = JointPosition()
    joint_position.joint_name = [f"joint{i+1}" for i in range(4)]
    if not args.trt:
        joint_position.position = [output[i].item() for i in range(4)]
    else:
        joint_position.position = [output[0][i].item() for i in range(4)]
    request = SetJointPosition.Request()
    request.joint_position = joint_position
    request.path_time = 4.0

    # Wait for the controller service before sending the goal.
    set_joint_position.wait_for_service()
    future = set_joint_position.call_async(request)
    rclpy.spin_until_future_complete(node, future)
    if future.result() is not None:
        print(f"result: {future.result().is_planned}")
    else:
        print(f"exception: {future.exception()}")

    node.destroy_node()
    rclpy.shutdown()
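
# The snippets above and below assume a project-local IKNet model. The sketch
# here is a minimal, hypothetical version for reference only: a fully connected
# network mapping a 7-D pose (x, y, z, qx, qy, qz, qw) to the 4 joint angles of
# the manipulator. Layer sizes are illustrative, not taken from the original code.
import torch
import torch.nn as nn


class IKNet(nn.Module):
    def __init__(self, hidden_units=(400, 300, 200, 100, 50)):
        super().__init__()
        layers = []
        in_features = 7  # position + orientation quaternion
        for units in hidden_units:
            layers += [nn.Linear(in_features, units), nn.ReLU()]
            in_features = units
        layers.append(nn.Linear(in_features, 4))  # one output per joint
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
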
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--kinematics-pose-csv",
                        type=str,
                        default="./dataset/test/kinematics_pose.csv")
    parser.add_argument("--joint-states-csv",
                        type=str,
                        default="./dataset/test/joint_states.csv")
    parser.add_argument("--batch-size", type=int, default=10000)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet()
    model.load_state_dict(torch.load("iknet.pth"))
    model.to(device)
    model.eval()

    dataset = IKDataset(args.kinematics_pose_csv, args.joint_states_csv)
    test_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=False)

    total_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Per-batch L2 error, averaged over the nominal batch size.
            total_loss += (output - target).norm().item() / args.batch_size
    print(f"Total loss = {total_loss}")
def objective(trial):
    """Optuna objective: train IKNet for a single hyperparameter trial."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet(trial)
    model.to(device)
    train_loader, val_loader = get_data_loaders(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    trigger = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger=(3, "epoch"), monitor="val/loss")
    my_extensions = [
        extensions.LogReport(),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        extensions.Evaluator(
            val_loader,
            model,
            eval_func=lambda data, target: validate(args, model, device, data,
                                                    target),
            progress_bar=True,
        ),
        extensions.PlotReport(["train/loss", "val/loss"],
                              "epoch",
                              filename="loss.png"),
        extensions.PrintReport([
            "epoch",
            "iteration",
            "train/loss",
            "lr",
            "val/loss",
        ]),
    ]
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=my_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=trigger,
    )
    return train(manager, args, model, device, train_loader)
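
# objective() above follows the usual Optuna pattern, so it would typically be
# driven by a study like this; the direction and trial count are assumptions.
import optuna

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=100)
print(f"best value: {study.best_value}, best params: {study.best_params}")
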
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-model",
        type=str,
        default="./iknet.pt",
    )
    parser.add_argument(
        "--output-model",
        type=str,
        default="./iknet.onnx",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet().to(device)
    model.load_state_dict(torch.load(args.input_model, map_location=device))
    model.eval()
    print(model)

    input_ = torch.ones(7).to(device)
    torch.onnx.export(
        model,
        input_,
        args.output_model,
        verbose=True,
        input_names=["pose"],
        output_names=["joints"],
    )
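
# A quick sanity check of the exported ONNX model with onnxruntime (an extra
# dependency, assumed to be installed); "pose" and "joints" match the names
# passed to torch.onnx.export() above.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("./iknet.onnx", providers=["CPUExecutionProvider"])
joints = session.run(["joints"], {"pose": np.ones(7, dtype=np.float32)})[0]
print(f"joints: {joints}")
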
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--kinematics-pose-csv",
        type=str,
        default="./dataset/train/kinematics_pose.csv",
    )
    parser.add_argument("--joint-states-csv",
                        type=str,
                        default="./dataset/train/joint_states.csv")
    parser.add_argument("--train-val-ratio", type=float, default=0.8)
    parser.add_argument("--batch-size", type=int, default=10000)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--save-model", action="store_true", default=False)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet()
    model.to(device)
    train_loader, val_loader = get_data_loaders(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    trigger = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger=(3, "epoch"), monitor="val/loss")
    my_extensions = [
        extensions.LogReport(),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        extensions.Evaluator(
            val_loader,
            model,
            eval_func=lambda data, target: validate(args, model, device, data,
                                                    target),
            progress_bar=True,
        ),
        extensions.PlotReport(["train/loss", "val/loss"],
                              "epoch",
                              filename="loss.png"),
        extensions.PrintReport([
            "epoch",
            "iteration",
            "train/loss",
            "lr",
            "val/loss",
        ]),
    ]
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=my_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=trigger,
    )
    train(manager, args, model, device, train_loader)

    if args.save_model:
        torch.save(model.state_dict(), "iknet.pt")
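
# get_data_loaders(), train(), and validate() are project helpers. A possible
# shape for the loader helper, assuming it splits IKDataset with
# args.train_val_ratio; the actual implementation may differ.
from torch.utils.data import DataLoader, random_split


def get_data_loaders(args):
    dataset = IKDataset(args.kinematics_pose_csv, args.joint_states_csv)
    train_size = int(len(dataset) * args.train_val_ratio)
    train_set, val_set = random_split(dataset,
                                      [train_size, len(dataset) - train_size])
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False)
    return train_loader, val_loader
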
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-model",
        type=str,
        default="./iknet.pth",
    )
    parser.add_argument(
        "--output-model",
        type=str,
        default="./iknet-trt.pth",
    )
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = IKNet().to(device)
    model.load_state_dict(torch.load(args.input_model))
    model.eval()
    print(model)

    # Dummy input with a batch dimension; the converted engine keeps this shape.
    input_ = torch.ones(1, 7).to(device)
    model_trt = torch2trt(model, [input_], fp16_mode=True)
    print(model_trt)
    torch.save(model_trt.state_dict(), args.output_model)
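
# Loading the converted engine back mirrors the --trt branch of the first
# snippet: torch2trt's TRTModule restores the engine from the saved state dict.
# The path is the default output used above; a CUDA device is required.
import torch
from torch2trt import TRTModule

model_trt = TRTModule()
model_trt.load_state_dict(torch.load("./iknet-trt.pth"))
print(model_trt(torch.ones(1, 7).cuda()))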