def main(weights_uri, onnx_name):
    """Convert a trained KeypointNet checkpoint into a single-image ONNX model.

    Args:
        weights_uri: path to a .pt checkpoint containing a 'model' state dict.
        onnx_name: destination path for the exported ONNX file.
    """
    net = KeypointNet(7, (80, 80), onnx_mode=True)
    # Map to CPU so the checkpoint loads regardless of where it was trained.
    state_dict = torch.load(weights_uri, map_location='cpu').get('model')
    net.load_state_dict(state_dict)
    # Trace with a single 3x80x80 dummy input (fixed batch of 1).
    torch.onnx.export(net, torch.randn(1, 3, 80, 80), onnx_name)
    print("onnx file conversion succeed and saved at: " + onnx_name)
def __init__(self):
    """Load KeypointNet weights and prepare the model for inference.

    Initializes bookkeeping state (`it` iteration counter, `image_patches`
    buffer) used by the detection loop elsewhere in this class.
    """
    # Hard-coded checkpoint path — presumably relative to the ROS workspace
    # root; TODO(review): confirm working directory at launch.
    self.weights = "./src/akhenaten_dv/scripts/Perception/KPDetection/weights.pt"
    self.img_size = 512
    self.model = KeypointNet()
    # Prefer GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.model.to(device)
    # BUG FIX: the original forced every weight tensor onto the GPU with
    # .cuda(), which crashes on CPU-only machines even though `device`
    # correctly falls back to CPU. Load onto CPU first, then move each
    # tensor to the selected device.
    model_dict = torch.load(self.weights, map_location='cpu').get('model')
    for k in model_dict.keys():
        model_dict[k] = model_dict[k].to(device)
    self.model.load_state_dict(model_dict)
    self.model.eval()
    self.it = 0
    self.image_patches = []
def main(weights_uri, onnx_name, opset):
    """Export KeypointNet to ONNX twice: once fixed-batch, once dynamic-batch.

    The fixed-shape graph is written to `onnx_name`; the dynamic-batch
    variant is written to the hard-coded "keypoints_dynamic.onnx".

    Args:
        weights_uri: path to a .pt checkpoint containing a 'model' state dict.
        onnx_name: destination path for the fixed-shape ONNX file.
        opset: ONNX opset version to target.
    """
    net = KeypointNet(7, (80, 80), onnx_mode=True)
    net.load_state_dict(torch.load(weights_uri, map_location='cpu').get('model'))

    sample = torch.randn(10, 3, 80, 80)
    in_names = ["actual_input_1"]
    out_names = ["output1"]

    # Fixed shape
    torch.onnx.export(net, sample, onnx_name,
                      verbose=True, opset_version=opset,
                      input_names=in_names, output_names=out_names)

    # Dynamic shape: let the batch dimension vary at inference time.
    dynamic_axes = {"actual_input_1": {0: "batch_size"},
                    "output1": {0: "batch_size"}}
    print(dynamic_axes)
    torch.onnx.export(net, sample, "keypoints_dynamic.onnx",
                      verbose=True, opset_version=opset,
                      input_names=in_names, output_names=out_names,
                      dynamic_axes=dynamic_axes)

    print("onnx file conversion succeed and saved at: " + onnx_name)
def main(model, img, img_size, output, flip, rotate):
    """Run KeypointNet on one image; save a heatmap strip and a keypoint overlay.

    Args:
        model: path to the .pt checkpoint (contains a 'model' state dict).
        img: path to the input image.
        img_size: square side length the image is resized to.
        output: directory prefix where output images are written.
        flip, rotate: accepted for CLI compatibility; unused here.
    """
    output_path = output
    model_filepath = model
    image_filepath = img
    # Image id = last 5 underscore-separated tokens of the basename (no ext).
    img_name = '_'.join(
        image_filepath.split('/')[-1].split('.')[0].split('_')[-5:])
    image_size = (img_size, img_size)

    image = cv2.imread(image_filepath)
    h, w, _ = image.shape
    image = prep_image(image=image, target_image_size=image_size)
    # HWC uint8 -> NCHW float in [0, 1].
    image = (image.transpose((2, 0, 1)) / 255.0)[np.newaxis, :]
    image = torch.from_numpy(image).type('torch.FloatTensor')

    # Use a distinct name so the `model` path parameter is not shadowed.
    net = KeypointNet()
    net.load_state_dict(torch.load(model_filepath).get('model'))
    net.eval()
    net_output = net(image)

    # Stack per-keypoint heatmap channels vertically, each min-max normalized.
    out = np.empty(shape=(0, net_output[0][0].shape[2]))
    for o in net_output[0][0]:
        chan = np.array(o.cpu().data)
        cmin = chan.min()
        cmax = chan.max()
        chan -= cmin
        # BUG FIX: guard against divide-by-zero when a channel is constant
        # (cmax == cmin left NaNs/inf in the visualization).
        if cmax > cmin:
            chan /= cmax - cmin
        out = np.concatenate((out, chan), axis=0)

    hm_path = output_path + img_name + "_hm.jpg"
    cv2.imwrite(hm_path, out * 255)
    # BUG FIX: the original f-string interpolated the tuple
    # (path, out * 255), dumping the raw array into the log; print the path.
    print(f'please check the output image here: {hm_path}')

    # Re-read the original image at native resolution for the overlay.
    image = cv2.imread(image_filepath)
    h, w, _ = image.shape
    vis_tensor_and_save(image=image, h=h, w=w,
                        tensor_output=net_output[1][0].cpu().data,
                        image_name=img_name, output_uri=output_path)
def train_model(model, output_uri, dataloader, loss_function, optimizer,
                scheduler, epochs, val_dataloader, intervals, input_size,
                num_kpt, save_checkpoints, kpt_keys, study_name, evaluate_mode):
    """Train KeypointNet with early stopping on validation loss.

    Per epoch: run one training pass, evaluate on the validation set, step
    the LR scheduler, export a best-so-far ONNX model when validation loss
    improves, and periodically save .pt checkpoints. Stops early after
    `max_tolerance` epochs without validation improvement.

    Args:
        model: KeypointNet instance (already on `device`).
        output_uri: directory for ONNX/checkpoint artifacts.
        dataloader / val_dataloader: train / validation DataLoaders.
        loss_function: CrossRatioLoss returning (loc, geo, total) losses.
        optimizer, scheduler: torch optimizer and LR scheduler.
        epochs: maximum number of epochs.
        intervals: checkpoint-saving interval in epochs.
        input_size: (H, W) network input size.
        num_kpt: number of keypoints (for the exported ONNX model).
        save_checkpoints: whether to write ONNX/.pt artifacts.
        kpt_keys, study_name, evaluate_mode: accepted for interface
            compatibility; unused in this function body.
    """
    best_val_loss = float('inf')
    best_epoch = 0
    max_tolerance = 8  # epochs without val improvement before early stop
    tolerance = 0

    for epoch in range(epochs):
        print(f"EPOCH {epoch}")
        model.train()
        total_loss = [0, 0, 0]  # location / geometric / total
        batch_num = 0
        train_process = tqdm(dataloader)
        for x_batch, y_hm_batch, y_points_batch, image_name, _ in train_process:
            x_batch = x_batch.to(device)
            y_hm_batch = y_hm_batch.to(device)
            y_points_batch = y_points_batch.to(device)

            # Zero the gradients. NOTE(review): the None-guard is vestigial —
            # optimizer.step() below is unconditional, so a None optimizer
            # would crash there anyway.
            if optimizer is not None:
                optimizer.zero_grad()

            # Compute output and loss.
            output = model(x_batch)
            loc_loss, geo_loss, loss = loss_function(
                output[0], output[1], y_hm_batch, y_points_batch)
            loss.backward()
            optimizer.step()

            loc_loss, geo_loss, loss = (
                loc_loss.item(), geo_loss.item(), loss.item())
            train_process.set_description(
                f"Batch {batch_num}. Location Loss: {round(loc_loss,5)}. "
                f"Geo Loss: {round(geo_loss,5)}. Total Loss: {round(loss,5)}")
            total_loss[0] += loc_loss
            total_loss[1] += geo_loss
            total_loss[2] += loss
            batch_num += 1

        print(f"\tTraining: MSE/Geometric/Total Loss: "
              f"{round(total_loss[0]/batch_num,10)}/"
              f"{round(total_loss[1]/batch_num,10)}/"
              f"{round(total_loss[2]/batch_num,10)}")
        val_loc_loss, val_geo_loss, val_loss = eval_model(
            model=model, dataloader=val_dataloader,
            loss_function=loss_function, input_size=input_size)
        # Position suggested by
        # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
        scheduler.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            tolerance = 0
            # Save model onnx for inference.
            # BUG FIX: dropped the original's `best_model = copy.deepcopy(model)`
            # — the copy was never read again and wasted memory on every
            # best epoch.
            if save_checkpoints:
                onnx_uri = os.path.join(
                    output_uri,
                    f"best_keypoints_{input_size[0]}{input_size[1]}.onnx")
                onnx_model = KeypointNet(num_kpt, input_size, onnx_mode=True)
                onnx_model.load_state_dict(model.state_dict())
                torch.onnx.export(
                    onnx_model,
                    torch.randn(1, 3, input_size[0], input_size[1]),
                    onnx_uri)
                print(f"Saving ONNX model to {onnx_uri}")
        else:
            tolerance += 1

        if save_checkpoints and epoch != 0 and (epoch + 1) % intervals == 0:
            # Save the latest weights.
            gs_pt_uri = os.path.join(
                output_uri,
                "{epoch}_loss_{loss}.pt".format(epoch=epoch,
                                                loss=round(val_loss, 2)))
            print(f'Saving model to {gs_pt_uri}')
            checkpoint = {'epoch': epoch,
                          'model': model.state_dict(),
                          'optimizer': optimizer.state_dict()}
            torch.save(checkpoint, gs_pt_uri)

        if tolerance >= max_tolerance:
            # BUG FIX: repaired the garbled early-stop message
            # ("stopped due; ... is has the best").
            print(f"Training stopped early: validation loss no longer "
                  f"decreases. Epoch {best_epoch} has the best validation loss.")
            break
def main():
    """CLI entry point: parse arguments, build datasets/model, train RektNet."""

    def add_bool_arg(name, default, help):
        # Register a paired --name / --no_name flag sharing one destination.
        arg_group = parser.add_mutually_exclusive_group(required=False)
        arg_group.add_argument('--' + name, dest=name, action='store_true',
                               help=help)
        arg_group.add_argument('--no_' + name, dest=name, action='store_false',
                               help=("Do not " + help))
        parser.set_defaults(**{name: default})

    parser = argparse.ArgumentParser(description='Keypoints Training with Pytorch')
    # BUG FIX: --input_size and --lr_gamma lacked type=; any user-supplied
    # value arrived as a string and broke INPUT_SIZE / ExponentialLR below.
    parser.add_argument('--input_size', default=80, type=int,
                        help='input image size')
    parser.add_argument('--train_dataset_uri', default='dataset/rektnet_label.csv',
                        help='training dataset csv directory path')
    parser.add_argument('--output_path', type=str, default="automatic",
                        help='output weights path, by default we will create a folder based on current system time and name of your cfg file')
    parser.add_argument('--dataset_path', type=str,
                        default="dataset/RektNet_Dataset/",
                        help='path to image dataset')
    parser.add_argument('--loss_type', default='l1_softargmax',
                        help='loss type: l2_softargmax|l2_heatmap|l1_softargmax')
    parser.add_argument('--validation_ratio', default=0.15, type=float,
                        help='percent of dataset to use for validation')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='size of each image batch')
    parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float,
                        help='learning rate')
    parser.add_argument('--lr_gamma', default=0.999, type=float,
                        help='gamma for the scheduler')
    parser.add_argument('--num_epochs', default=1024, type=int,
                        help='number of epochs')
    parser.add_argument("--checkpoint_interval", type=int, default=4,
                        help="interval between saving model weights")
    parser.add_argument('--study_name', required=True,
                        help='name for saving checkpoint models')
    add_bool_arg('geo_loss', default=True, help='whether to add in geo loss')
    # BUG FIX: the (horizontal)/(vertical) help labels were swapped between
    # the vert and horz flags in the original.
    parser.add_argument('--geo_loss_gamma_vert', default=0, type=float,
                        help='gamma for the geometric loss (vertical)')
    parser.add_argument('--geo_loss_gamma_horz', default=0, type=float,
                        help='gamma for the geometric loss (horizontal)')
    add_bool_arg('vis_upload_data', default=False,
                 help='whether to visualize our dataset in Christmas Tree format and upload the whole dataset to. default to False')
    add_bool_arg('save_checkpoints', default=True,
                 help='whether to save checkpoints')
    add_bool_arg('vis_dataloader', default=False,
                 help='whether to visualize the image points and heatmap processed in our dataloader')
    add_bool_arg('evaluate_mode', default=False,
                 help='whether to evaluate avg kpt mse vs BB size distribution at end of training')
    args = parser.parse_args()

    print("Program arguments:", args)

    if args.output_path == "automatic":
        current_month = datetime.now().strftime('%B').lower()
        current_year = str(datetime.now().year)
        # Build the experiment directory path once (the original repeated
        # the same os.path.join three times).
        output_uri = os.path.join(
            'outputs/',
            current_month + '-' + current_year + '-experiments/'
            + args.study_name + '/')
        if not os.path.exists(output_uri):
            os.makedirs(output_uri)
    else:
        output_uri = args.output_path

    # Tee stdout/stderr into per-study log files.
    save_file_name = 'logs/' + output_uri.split('/')[-2]
    sys.stdout = Logger(save_file_name + '.log')
    sys.stderr = Logger(save_file_name + '.error')

    INPUT_SIZE = (args.input_size, args.input_size)
    KPT_KEYS = ["top", "mid_L_top", "mid_R_top", "mid_L_bot", "mid_R_bot",
                "bot_L", "bot_R"]

    intervals = args.checkpoint_interval
    val_split = args.validation_ratio
    batch_size = args.batch_size
    num_epochs = args.num_epochs

    # Load the train data.
    train_csv = args.train_dataset_uri
    train_images, train_labels, val_images, val_labels = load_train_csv_dataset(
        train_csv, validation_percent=val_split, keypoint_keys=KPT_KEYS,
        dataset_path=args.dataset_path, cache_location="./gs/")

    # "Become one with the data" - Andrej Karpathy
    if args.vis_upload_data:
        visualize_data(train_images, train_labels)
        print('Shutting down instance...')
        # Deliberate: powers off the (cloud) machine after the upload run.
        os.system('sudo shutdown now')

    # Create pytorch dataloaders for train and validation sets.
    train_dataset = ConeDataset(images=train_images, labels=train_labels,
                                dataset_path=args.dataset_path,
                                target_image_size=INPUT_SIZE,
                                save_checkpoints=args.save_checkpoints,
                                vis_dataloader=args.vis_dataloader)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=False, num_workers=0)
    val_dataset = ConeDataset(images=val_images, labels=val_labels,
                              dataset_path=args.dataset_path,
                              target_image_size=INPUT_SIZE,
                              save_checkpoints=args.save_checkpoints,
                              vis_dataloader=args.vis_dataloader)
    val_dataloader = DataLoader(val_dataset, batch_size=1,
                                shuffle=False, num_workers=0)

    # Define model, optimizer and loss function.
    model = KeypointNet(len(KPT_KEYS), INPUT_SIZE, onnx_mode=False)
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
    loss_func = CrossRatioLoss(args.loss_type, args.geo_loss,
                               args.geo_loss_gamma_horz,
                               args.geo_loss_gamma_vert)

    # Train our model.
    train_model(
        model=model,
        output_uri=output_uri,
        dataloader=train_dataloader,
        loss_function=loss_func,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=num_epochs,
        val_dataloader=val_dataloader,
        intervals=intervals,
        input_size=INPUT_SIZE,
        num_kpt=len(KPT_KEYS),
        save_checkpoints=args.save_checkpoints,
        kpt_keys=KPT_KEYS,
        study_name=args.study_name,
        evaluate_mode=args.evaluate_mode
    )