import os

import numpy as np
import torch
from torch.utils.data import DataLoader

# Project-specific helpers (construct_dataset, construct_network, init_from_modelzoo,
# weighted_l1_loss, weighted_mse_loss, the predict and parameter modules) and the
# globals (device, learning_rate, n_epoch, enableKeypointPos, heatmap_loss_weight,
# writer) are assumed to be provided by the surrounding package.


def train(checkpoint_dir: str, start_from_ckpnt: str = '', save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train, batch_size=1, shuffle=True, num_workers=1)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    # control_network = ControlNetwork(
    #     in_channel=int(net_config.num_keypoints * net_config.depth_per_keypoint * 256/4 * 256/4))
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.to(device)
    # control_network.to(device)

    # The checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The losses: RMSE for rotation, cosine + RMSE for translation, BCE for step size
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss(reduction='none')

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90], gamma=0.1)

    # The training loop
    for epoch in range(1, n_epoch + 1):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_rot = 0
        train_error_xyz = 0
        train_error_step = 0

        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key][0]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key][0]
            keypoint_weight = data[parameter.keypoint_validity_key][0]
            delta_rot = data[parameter.delta_rot_key][0]
            delta_xyz = data[parameter.delta_xyz_key][0]
            gripper_pose = data[parameter.gripper_pose_key][0]
            step_size = data[parameter.step_size_key][0]

            # Upload to GPU
            image = image.to(device)
            keypoint_xy_depth = keypoint_xy_depth.to(device)
            keypoint_weight = keypoint_weight.to(device)
            delta_rot = delta_rot.to(device)
            delta_xyz = delta_xyz.to(device)
            gripper_pose = gripper_pose.to(device)
            step_size = step_size.to(device)

            # To predict
            optimizer.zero_grad()
            xy_depth_pred, delta_rot_pred, delta_xyz_pred, step_size_pred = network(
                image, gripper_pose, device, enableKeypointPos=enableKeypointPos)

            # The commented ControlNetwork path predicted the gripper motion from the
            # flattened raw keypoint prediction instead:
            # raw_pred_flatten = torch.flatten(raw_pred, start_dim=1)
            # delta_rot_pred, delta_xyz_pred, step_size_pred = control_network(raw_pred_flatten)

            # The rotation loss. An earlier revision penalized the deviation of
            # delta_rot_pred^T * delta_rot from the identity instead:
            # identity = torch.eye(3).unsqueeze(0).repeat(image.shape[0], 1, 1).to(device)
            # identity_hat = torch.matmul(torch.transpose(delta_rot_pred, 1, 2), delta_rot)
            # loss_r = criterion_rmse(identity, identity_hat)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)

            # The translation loss: direction (cosine) plus magnitude (RMSE)
            loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)).mean() \
                + criterion_rmse(delta_xyz_pred, delta_xyz)
            # loss_t = criterion_rmse(delta_xyz_pred, delta_xyz)

            # The step-size loss
            loss_s = criterion_bce(step_size_pred, step_size).mean()

            # The keypoint loss. (An earlier revision decoded xy_depth_pred here from
            # the raw prediction via predict.heatmap_from_predict,
            # heatmap2d_to_normalized_imgcoord_{cpu,gpu} and depth_integration;
            # the network now returns xy_depth_pred directly.)
            loss_kpt = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            loss = loss_kpt * 15 + loss_r * 10 + loss_t * 10 + loss_s
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = float(weighted_l1_loss(
                xy_depth_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            depth_error = float(weighted_l1_loss(
                xy_depth_pred[:, :, 2], keypoint_xy_depth[:, :, 2],
                keypoint_weight[:, :, 2]).item())

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_rot += float(loss_r)
            train_error_xyz += float(loss_t)
            train_error_step += float(loss_s)

            # cleanup
            del loss

        # The learning rate step, after the epoch's optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print('The training averaged pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy / len(dataset_train))
        print('The training averaged depth error is (mm): ',
              train_config.depth_image_scale * train_error_depth / len(dataset_train))
        print('The training averaged rot error is: ', train_error_rot / len(dataset_train))
        print('The training averaged xyz error is: ', train_error_xyz / len(dataset_train))
        print('The training averaged step error is: ', train_error_step / len(dataset_train))
        writer.add_scalar('average pixel error', 256 * train_error_xy / len(dataset_train), epoch)
        writer.add_scalar('average depth error',
                          train_config.depth_image_scale * train_error_depth / len(dataset_train),
                          epoch)
        writer.add_scalar('average rot error', train_error_rot / len(dataset_train), epoch)
        writer.add_scalar('average xyz error', train_error_xyz / len(dataset_train), epoch)
        writer.add_scalar('average step error', train_error_step / len(dataset_train), epoch)
    writer.close()
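
# The RMSELoss criterion instantiated above is not a torch.nn builtin. A minimal
# sketch of what this variant assumes (hypothetical; the project's own definition
# is authoritative):
class RMSELoss(torch.nn.Module):

    def __init__(self, eps: float = 1e-8):
        super(RMSELoss, self).__init__()
        self.mse = torch.nn.MSELoss()
        self.eps = eps  # keeps the sqrt differentiable at zero error

    def forward(self, pred, target):
        return torch.sqrt(self.mse(pred, target) + self.eps)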
def train(checkpoint_dir: str, start_from_ckpnt: str = '', save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train, batch_size=64, shuffle=True, num_workers=4)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90], gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()
    # heatmap_criterion = torch.nn.KLDivLoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            # keypoint_xy_depth (batch_size, num_keypoint, xy + depth)
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]
            # if idx == 0:
            #     np.save('rgbd.npy', image[0])
            #     np.save('target_heatmap.npy', target_heatmap[0])

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()
            # raw_pred (batch_size, num_keypoint*2, network_out_map_height, network_out_map_width):
            # the first num_keypoints channels are the heatmap logits (prob_pred),
            # the remaining channels the per-keypoint depth maps (depthmap_pred)
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]

            # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            heatmap = predict.heatmap_from_predict(prob_pred, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = heatmap.shape

            # Integrate the depth over the heatmap
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)
            depth_pred = depth_pred[:, :, 0]

            # Compute the losses
            depth_loss = weighted_l1_loss(depth_pred, keypoint_xy_depth[:, :, 2],
                                          keypoint_weight[:, :, 2])
            heatmap_loss = heatmap_criterion(heatmap, target_heatmap)
            if idx % 100 == 0:
                # Debug dump of the current heatmaps
                np.save('pred_heatmap.npy', heatmap.cpu().detach().numpy())
                np.save('target_heatmap.npy', target_heatmap.cpu().detach().numpy())
                print('depth loss:', depth_loss)
                print('heatmap loss:', heatmap_loss)

            # Only the heatmap term is used for the update in this variant
            # loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            # loss = depth_loss + 1500 * heatmap_loss
            loss = heatmap_loss
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(heatmap)
            xy_error = float(weighted_l1_loss(
                keypoint_xy_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            depth_error = float(weighted_l1_loss(
                depth_pred, keypoint_xy_depth[:, :, 2], keypoint_weight[:, :, 2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])
                print('The averaged depth error is (mm): ',
                      train_config.depth_image_scale * depth_error / len(depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)

        # The learning rate step, after the epoch's optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print('The training averaged pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy / len(dataset_train))
        print('The training averaged depth error is (mm): ',
              train_config.depth_image_scale * train_error_depth / len(dataset_train))
        writer.add_scalar('average pixel error', 256 * train_error_xy / len(dataset_train), epoch)
        writer.add_scalar('average depth error',
                          train_config.depth_image_scale * train_error_depth / len(dataset_train),
                          epoch)
    writer.close()
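
# Every variant in this file masks out invalid keypoints through weighted_l1_loss
# (and, further below, weighted_mse_loss). Minimal sketches consistent with how
# the callers normalize the result (a plain weighted sum, later divided by the
# batch or dataset size); these are assumptions, not necessarily the project's
# exact definitions:
def weighted_l1_loss(pred, target, weight):
    # weight zeroes out invalid keypoints before the reduction
    return (weight * (pred - target).abs()).sum()


def weighted_mse_loss(pred, target, weight):
    # the squared-error counterpart used by the integral-regression variant below
    return (weight * (pred - target) ** 2).sum()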
def train(checkpoint_dir: str, start_from_ckpnt: str = '', save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train, batch_size=16, shuffle=True, num_workers=4)
    loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40], gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 100 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0

        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            _, _, heatmap_height, heatmap_width = prob_pred.shape

            # Compute loss
            loss = heatmap_criterion(prob_pred, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Do some pred and log
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(prob_pred)
            xy_error = float(weighted_l1_loss(
                keypoint_xy_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)

        # The learning rate step, after the epoch's optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print('The training averaged pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy / len(dataset_train))

        # Prepare info for validation
        network.eval()
        val_error_xy = 0

        # The validation iteration over the data; no gradients are needed here
        with torch.no_grad():
            for idx, data in enumerate(loader_val):
                # Get the data
                image = data[parameter.rgbd_image_key]
                keypoint_xy_depth = data[parameter.keypoint_xyd_key]
                keypoint_weight = data[parameter.keypoint_validity_key]
                image = image.cuda()
                keypoint_xy_depth = keypoint_xy_depth.cuda()
                keypoint_weight = keypoint_weight.cuda()

                # To predict
                pred = network(image)
                prob_pred = pred[:, 0:net_config.num_keypoints, :, :]
                _, _, heatmap_height, heatmap_width = prob_pred.shape

                # Compute the coordinate
                keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(prob_pred)
                xy_error = float(weighted_l1_loss(
                    keypoint_xy_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                    keypoint_weight[:, :, 0:2]).item())

                # Update info
                val_error_xy += float(xy_error)

        # The info at epoch level
        print('The validation averaged pixel error is (pixel in 256x256 image): ',
              256 * val_error_xy / len(dataset_val))
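
# The logging above reads keypoint coordinates off the heatmap with a hard argmax.
# A sketch of predict.heatmap2d_to_normalized_imgcoord_argmax, under the assumed
# convention that coordinates are normalized to [-0.5, 0.5] over the heatmap
# (hypothetical; the project's predict module is authoritative):
def heatmap2d_to_normalized_imgcoord_argmax(heatmap):
    # heatmap: (batch, num_keypoints, height, width)
    batch, n_keypoint, height, width = heatmap.shape
    flat_index = heatmap.reshape(batch, n_keypoint, -1).argmax(dim=2)
    coord_y = torch.div(flat_index, width, rounding_mode='floor')
    coord_x = flat_index % width
    normalized_x = coord_x.float() / width - 0.5
    normalized_y = coord_y.float() / height - 0.5
    # (batch, num_keypoints, 2) normalized xy, plus the raw flat indices
    return torch.stack((normalized_x, normalized_y), dim=2), flat_index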
def train(checkpoint_dir: str, start_from_ckpnt: str = '', save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train, batch_size=8, shuffle=True, num_workers=4)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()

    # To cuda; wrap first so the checkpoint keys carry the 'module.' prefix
    network = torch.nn.DataParallel(network).cuda()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))

    # The checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80], gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 2 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_internal = 0

        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]

            # Move to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict; the network returns one prediction per stage
            optimizer.zero_grad()
            raw_pred = network(image)

            # The last stage predicts both heatmap logits and depth maps
            raw_pred_last = raw_pred[-1]
            prob_pred_last = raw_pred_last[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred_last = raw_pred_last[:, net_config.num_keypoints:, :, :]
            heatmap_last = predict.heatmap_from_predict(prob_pred_last, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = heatmap_last.shape

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                heatmap_last, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap_last, depthmap_pred_last)

            # Concatenate them
            xy_depth_pred = torch.cat([coord_x, coord_y, depth_pred], dim=2)

            # Compute the loss on the last stage
            loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth, keypoint_weight)

            # For all other stages, which only make 2d predictions
            for stage_i in range(len(raw_pred) - 1):
                prob_pred_i = raw_pred[stage_i]
                assert prob_pred_i.shape == prob_pred_last.shape
                heatmap_i = predict.heatmap_from_predict(prob_pred_i, net_config.num_keypoints)
                coord_x_i, coord_y_i = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap_i, net_config.num_keypoints)
                xy_pred_i = torch.cat([coord_x_i, coord_y_i], dim=2)
                loss = loss + weighted_l1_loss(
                    xy_pred_i, keypoint_xy_depth[:, :, 0:2], keypoint_weight[:, :, 0:2])

            # The SGD step
            loss.backward()
            optimizer.step()
            del loss

            # Log info
            xy_error = float(weighted_l1_loss(
                xy_depth_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            depth_error = float(weighted_l1_loss(
                xy_depth_pred[:, :, 2], keypoint_xy_depth[:, :, 2],
                keypoint_weight[:, :, 2]).item())

            # The error of the first (internal) stage
            keypoint_xy_pred_internal, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                raw_pred[0])
            xy_error_internal = float(weighted_l1_loss(
                keypoint_xy_pred_internal[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print('The averaged depth error is (mm): ',
                      train_config.depth_image_scale * depth_error / len(xy_depth_pred))
                print('The averaged internal pixel error is (pixel in 256x256 image): ',
                      256 * xy_error_internal / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_internal += float(xy_error_internal)

        # The learning rate step, after the epoch's optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print('The training averaged pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy / len(dataset_train))
        print('The training averaged depth error is (mm): ',
              train_config.depth_image_scale * train_error_depth / len(dataset_train))
        print('The training internal averaged pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy_internal / len(dataset_train))
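
# The staged training above leans on two predict helpers: a spatial softmax that
# turns per-keypoint logits into a normalized heatmap, and an expectation of the
# depth map under that heatmap. Minimal sketches under assumed shapes
# (batch, num_keypoints, height, width); hypothetical, the project's predict
# module is authoritative:
def heatmap_from_predict(prob_pred, num_keypoints):
    assert prob_pred.shape[1] == num_keypoints
    batch, _, height, width = prob_pred.shape
    flat = prob_pred.reshape(batch, num_keypoints, -1)
    heatmap = torch.nn.functional.softmax(flat, dim=2)  # normalize per keypoint
    return heatmap.reshape(batch, num_keypoints, height, width)


def depth_integration(heatmap, depthmap_pred):
    # Expected depth: sum over pixels of heatmap probability times depth value,
    # returned as (batch, num_keypoints, 1) so it can be cat'ed with coord_x/y
    batch, n_keypoint, _, _ = heatmap.shape
    product = (heatmap * depthmap_pred).reshape(batch, n_keypoint, -1)
    return product.sum(dim=2, keepdim=True)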
def train(checkpoint_dir: str, start_from_ckpnt: str = '', save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90], gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 4 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_heatmap = 0

        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict; the channels are [heatmap logits, depth maps, regressed heatmap]
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:2 * net_config.num_keypoints, :, :]
            regress_heatmap = raw_pred[:, 2 * net_config.num_keypoints:, :, :]
            integral_heatmap = predict.heatmap_from_predict(prob_pred, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = integral_heatmap.shape

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                integral_heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(integral_heatmap, depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss: integral-regression term plus weighted heatmap term
            loss = weighted_mse_loss(xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            loss = loss + heatmap_loss_weight * heatmap_criterion(regress_heatmap, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = float(weighted_l1_loss(
                xy_depth_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            depth_error = float(weighted_l1_loss(
                xy_depth_pred[:, :, 2], keypoint_xy_depth[:, :, 2],
                keypoint_weight[:, :, 2]).item())
            keypoint_xy_pred_heatmap, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                regress_heatmap)
            xy_error_heatmap = float(weighted_l1_loss(
                keypoint_xy_pred_heatmap[:, :, 0:2], keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print('The averaged depth error is (mm): ',
                      train_config.depth_image_scale * depth_error / len(xy_depth_pred))
                print('The averaged heatmap argmax pixel error is (pixel in 256x256 image): ',
                      256 * xy_error_heatmap / len(xy_depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_heatmap += float(xy_error_heatmap)

        # The learning rate step, after the epoch's optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print('The training averaged pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy / len(dataset_train))
        print('The training averaged depth error is (mm): ',
              train_config.depth_image_scale * train_error_depth / len(dataset_train))
        print('The training averaged heatmap pixel error is (pixel in 256x256 image): ',
              256 * train_error_xy_heatmap / len(dataset_train))
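
# A possible entry point for any of the train() variants above. The learning
# rate, epoch count, and log directory are placeholder values, and writer is the
# TensorBoard writer that the first two variants log to:
if __name__ == '__main__':
    from torch.utils.tensorboard import SummaryWriter

    learning_rate = 1e-4  # placeholder
    n_epoch = 120         # placeholder
    writer = SummaryWriter(log_dir='runs/keypoint_training')  # placeholder path
    train(checkpoint_dir='ckpnt', start_from_ckpnt='', save_epoch_offset=0)
    writer.close()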