def main():
    """Train IKNet from command-line options.

    Parses CLI options, builds the model, the train/validation data loaders
    and an Adam optimizer, and wires them into a pytorch-pfn-extras
    ``ExtensionsManager``. The manager is given logging, progress, learning
    rate and parameter statistics, validation, a loss plot, and a printed
    report. It runs for ``--epochs`` epochs with an early-stopping trigger
    on ``val/loss``. Training is delegated to ``train``. With
    ``--save-model`` the resulting ``state_dict`` is written to
    ``iknet.pt`` in the current working directory.
    """
    parser = argparse.ArgumentParser()
    # (flag, add_argument keyword options), registered in this order.
    cli_spec = (
        (
            "--kinematics-pose-csv",
            {"type": str, "default": "./dataset/train/kinematics_pose.csv"},
        ),
        (
            "--joint-states-csv",
            {"type": str, "default": "./dataset/train/joint_states.csv"},
        ),
        ("--train-val-ratio", {"type": float, "default": 0.8}),
        ("--batch-size", {"type": int, "default": 10000}),
        ("--epochs", {"type": int, "default": 100}),
        ("--lr", {"type": float, "default": 0.01}),
        ("--save-model", {"action": "store_true", "default": False}),
    )
    for flag, options in cli_spec:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Keep model construction ahead of loader creation: either may consume
    # the global RNG, so the original ordering is preserved.
    model = IKNet()
    model.to(device)
    train_loader, val_loader = get_data_loaders(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Early-stopping trigger monitoring "val/loss", consulted on a
    # (3, "epoch") check schedule.
    stopper = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger=(3, "epoch"), monitor="val/loss"
    )

    def run_validation(data, target):
        # Evaluation callback for the Evaluator extension; receives (data, target).
        return validate(args, model, device, data, target)

    report_columns = ["epoch", "iteration", "train/loss", "lr", "val/loss"]
    manager_extensions = [
        extensions.LogReport(),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        extensions.Evaluator(
            val_loader,
            model,
            eval_func=run_validation,
            progress_bar=True,
        ),
        extensions.PlotReport(
            ["train/loss", "val/loss"], "epoch", filename="loss.png"
        ),
        extensions.PrintReport(report_columns),
    ]

    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=manager_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=stopper,
    )

    train(manager, args, model, device, train_loader)

    if args.save_model:
        torch.save(model.state_dict(), "iknet.pt")
def objective(trial):
    """Train one IKNet configuration for a hyperparameter-search trial.

    Builds ``IKNet(trial)``, the data loaders and an Adam optimizer, and
    wires them into a pytorch-pfn-extras ``ExtensionsManager``. The manager
    is given logging, statistics, validation, plot and report extensions,
    plus an early-stopping trigger on ``val/loss``. Returns whatever
    ``train`` returns.

    NOTE(review): ``trial`` is presumably an Optuna ``Trial`` that ``IKNet``
    samples its architecture from, and the return value is presumably the
    score the study optimizes. Confirm against IKNet/train.
    NOTE(review): ``args`` is neither a parameter nor a local here. It is
    looked up in module scope at call time, so a global ``args`` must exist
    before this function runs. Confirm against the caller.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Keep model construction ahead of loader creation: either may consume
    # the global RNG, so the original ordering is preserved.
    model = IKNet(trial)
    model.to(device)
    train_loader, val_loader = get_data_loaders(args)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Early-stopping trigger monitoring "val/loss", consulted on a
    # (3, "epoch") check schedule.
    stopper = ppe.training.triggers.EarlyStoppingTrigger(
        check_trigger=(3, "epoch"), monitor="val/loss"
    )

    def run_validation(data, target):
        # Evaluation callback for the Evaluator extension; receives (data, target).
        return validate(args, model, device, data, target)

    report_columns = ["epoch", "iteration", "train/loss", "lr", "val/loss"]
    manager_extensions = [
        extensions.LogReport(),
        extensions.ProgressBar(),
        extensions.observe_lr(optimizer=optimizer),
        extensions.ParameterStatistics(model, prefix="model"),
        extensions.VariableStatisticsPlot(model),
        extensions.Evaluator(
            val_loader,
            model,
            eval_func=run_validation,
            progress_bar=True,
        ),
        # NOTE(review): every trial uses the same "loss.png" filename, so
        # later trials presumably overwrite earlier plots. Confirm out_dir.
        extensions.PlotReport(
            ["train/loss", "val/loss"], "epoch", filename="loss.png"
        ),
        extensions.PrintReport(report_columns),
    ]

    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        args.epochs,
        extensions=manager_extensions,
        iters_per_epoch=len(train_loader),
        stop_trigger=stopper,
    )

    return train(manager, args, model, device, train_loader)