コード例 #1
0
    def test_imgcoord_cpu(self):
        """Check heatmap2d_to_imgcoord_cpu against a brute-force expectation.

        A random prediction is converted into a heatmap, the CPU routine is
        asked for per-keypoint x / y coordinates, and the result for one
        (batch, keypoint) pair is compared with a plain double loop that
        accumulates heatmap[y, x] * x and heatmap[y, x] * y.
        """
        from mankey.network.predict import heatmap_from_predict, heatmap2d_to_imgcoord_cpu

        n_batch, n_keypoint, image_size = 16, 6, 256
        tolerance = 1e-4

        # Random scores stand in for a network prediction
        scores = torch.rand(size=(n_batch, n_keypoint, image_size, image_size))
        heatmap = heatmap_from_predict(scores, n_keypoint)
        coord_x, coord_y = heatmap2d_to_imgcoord_cpu(heatmap,
                                                     num_keypoints=n_keypoint)

        # Both outputs must hold one scalar per (batch, keypoint)
        expected_shape = (n_batch, n_keypoint, 1)
        self.assertEqual(coord_x.shape, expected_shape)
        self.assertEqual(coord_y.shape, expected_shape)

        # Brute-force reference. The pure-Python loops are slow, so only a
        # single (batch, keypoint) pair is verified.
        batch_idx, keypoint_idx = 0, 0
        single_map = heatmap[batch_idx, keypoint_idx, :, :].numpy()
        expected_x = 0
        expected_y = 0
        for row in range(image_size):
            for col in range(image_size):
                weight = single_map[row, col]
                expected_x += weight * col
                expected_y += weight * row

        # Compare the reference with what the routine returned
        actual_x = float(coord_x[batch_idx, keypoint_idx, 0].item())
        actual_y = float(coord_y[batch_idx, keypoint_idx, 0].item())
        self.assertTrue(abs(expected_x - actual_x) < tolerance)
        self.assertTrue(abs(expected_y - actual_y) < tolerance)
コード例 #2
0
def inference_hourglass_staged(
    network,  # type: hourglass_staged.HourglassNet
    imgproc_out,  # type: ImageProcOut
):  # type: (hourglass_staged.HourglassNet, ImageProcOut) -> np.ndarray
    """
    Run the staged network on one processed image and collect the keypoint prediction.

    Only the last element of the network output is used. Its channels are split
    in half: the first half is turned into a heatmap, the second half is used
    as the depth map that is integrated against that heatmap.

    :param network: called with the image after it has been moved to CUDA
    :param imgproc_out: its stacked_rgbd numpy array is the network input
    :return: (3, n_keypoint) np array for the single input image. Rows 0 and 1
             are the x and y coords returned by
             predict.heatmap2d_to_normalized_imgcoord_gpu, row 2 is the result
             of predict.depth_integration
    """
    # Wrap the image into a batch of one and upload it
    batch = torch.unsqueeze(torch.from_numpy(imgproc_out.stacked_rgbd), dim=0)
    batch = batch.cuda()

    # Forward pass; keep the final stage only
    last_stage = network(batch)[-1]
    n_channels = last_stage.shape[1]
    num_keypoints = n_channels // 2
    # The channel count must be even: one probability and one depth map per keypoint
    assert n_channels == 2 * num_keypoints

    heatmap = predict.heatmap_from_predict(
        last_stage[:, 0:num_keypoints, :, :], num_keypoints)
    depthmap_pred = last_stage[:, num_keypoints:, :, :]
    coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
        heatmap, num_keypoints)
    depth_pred = predict.depth_integration(heatmap, depthmap_pred)

    # Bring each result back to numpy and keep the only batch element
    rows = [
        tensor.cpu().detach().numpy()[0, :, 0]
        for tensor in (coord_x, coord_y, depth_pred)
    ]

    # Stack x, y and depth into one (3, n_keypoint) array
    keypointxy_depth_pred = np.zeros((3, num_keypoints))
    for row_idx, values in enumerate(rows):
        keypointxy_depth_pred[row_idx, :] = values
    return keypointxy_depth_pred
コード例 #3
0
    def test_depth_integration(self):
        """Check depth_integration against a brute-force weighted sum.

        The same random tensor is used both as the raw prediction that the
        heatmap is built from and as the depth map. For one (batch, keypoint)
        pair the result must match the sum of heatmap[y, x] * depthmap[y, x]
        over all pixels.
        """
        from mankey.network.predict import heatmap_from_predict, depth_integration

        n_batch, n_keypoint, image_size = 16, 6, 256

        # Random scores stand in for a network prediction
        scores = torch.rand(size=(n_batch, n_keypoint, image_size, image_size))
        heatmap = heatmap_from_predict(scores, n_keypoint)
        depth_pred = depth_integration(heatmap, scores)

        # Brute-force reference. The pure-Python loops are slow, so only a
        # single (batch, keypoint) pair is verified.
        batch_idx, keypoint_idx = 0, 0
        weights = heatmap[batch_idx, keypoint_idx, :, :].numpy()
        depths = scores[batch_idx, keypoint_idx, :, :].numpy()
        expected_depth = 0
        for row in range(image_size):
            for col in range(image_size):
                expected_depth += weights[row, col] * depths[row, col]

        # Compare the reference with what the routine returned
        actual_depth = float(depth_pred[batch_idx, keypoint_idx, 0].item())
        self.assertTrue(abs(expected_depth - actual_depth) < 1e-4)
コード例 #4
0
 def forward(self, x, gripper_pose, device, enableKeypointPos=True):
     """
     Predict keypoints from the input and regress the gripper control outputs.

     :param x: network input, passed straight to self.backbone_net
     :param gripper_pose: not used by this method; kept so existing callers keep working
     :param device: 'cpu' (or a device object whose str() is 'cpu') selects the CPU
                    coordinate routine, anything else the GPU one. It is also
                    forwarded to compute_rotation_matrix_from_ortho6d
     :param enableKeypointPos: when truthy, the flattened keypoint prediction is
                               concatenated to the mlp_2 features and fed to mlp_3;
                               otherwise mlp_1 is applied to the backbone features alone
     :return: (xy_depth_pred, out_r, out_t, out_step_size). xy_depth_pred is the
              x / y / depth prediction concatenated along dim 2, out_r is built
              from output columns 0:6, out_t holds columns 6:9 and out_step_size
              is the sigmoid of column 9
     """
     # Per the original note, x_feature's size is (batch, 512, 8, 8)
     x_feature = self.backbone_net(x)
     x_heatmap = self.head_net(x_feature)
     x_feature_flatten = torch.flatten(x_feature, start_dim=1)

     # keypoint: the first num_keypoints channels are probabilities, the rest depth maps
     num_keypoints = self.config.num_keypoints
     prob_pred = x_heatmap[:, 0:num_keypoints, :, :]
     depthmap_pred = x_heatmap[:, num_keypoints:, :, :]
     # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
     heatmap = predict.heatmap_from_predict(prob_pred, num_keypoints)
     # Compare via str() so that a torch.device('cpu') object is also recognised;
     # a plain `device == 'cpu'` only matches the string form.
     if str(device) == 'cpu':
         coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_cpu(heatmap, num_keypoints)
     else:
         coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(heatmap, num_keypoints)
     depth_pred = predict.depth_integration(heatmap, depthmap_pred)
     xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

     if enableKeypointPos:
         # Append the flattened keypoint prediction to the image features
         _out = self.mlp_2(x_feature_flatten)
         kp_pos_flatten = torch.flatten(xy_depth_pred, start_dim=1)
         _out = torch.cat((_out, kp_pos_flatten), dim=1)
         out = self.mlp_3(_out)
     else:
         out = self.mlp_1(x_feature_flatten)

     # gripper control
     out_r_6d = out[:, 0:6]
     out_r = compute_rotation_matrix_from_ortho6d(out_r_6d, device)  # batch*3*3 per the original note
     out_t = out[:, 6:9].view(-1, 3)  # batch*3
     out_step_size = torch.sigmoid(out[:, 9]).view(-1, 1)  # batch*1
     return xy_depth_pred, out_r, out_t, out_step_size
コード例 #5
0
def inference_resnet_nostage(
    network,  # type: resnet_nostage.ResnetNoStage
    imgproc_out,  # type: ImageProcOut
):  # type: (resnet_nostage.ResnetNoStage, ImageProcOut) -> np.ndarray
    """
    Run the network on one processed image and collect the keypoint prediction.

    :param network: The network must be on GPU (the input is moved to CUDA before the call)
    :param imgproc_out: its stacked_rgbd numpy array is the network input
    :return: (3, n_keypoint) np array. (0:2, :) are x and y coords of keypoints in [-0.5, 0.5]
                                       (2, :) are the scaled depth
    """
    # Upload the image as a batch of one
    stacked_rgbd = torch.from_numpy(imgproc_out.stacked_rgbd)
    stacked_rgbd = torch.unsqueeze(stacked_rgbd, dim=0)
    stacked_rgbd = stacked_rgbd.cuda()

    # Do forward
    raw_pred = network(stacked_rgbd)
    num_keypoints = raw_pred.shape[1] // 2
    # The channel count must be even: one probability and one depth map per keypoint
    assert raw_pred.shape[1] == 2 * num_keypoints
    prob_pred = raw_pred[:, 0:num_keypoints, :, :]
    depthmap_pred = raw_pred[:, num_keypoints:, :, :]
    heatmap = predict.heatmap_from_predict(prob_pred, num_keypoints)

    # Coordinates by regression over the heatmap.
    # An argmax alternative was left disabled here:
    #   keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(heatmap)
    #   coord_x = keypoint_xy_pred[:, :, 0:1]
    #   coord_y = keypoint_xy_pred[:, :, 1:]
    coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
        heatmap, num_keypoints)

    depth_pred = predict.depth_integration(heatmap, depthmap_pred)

    # The scaled image coord and depth
    coord_x = coord_x.cpu().detach().numpy()
    coord_y = coord_y.cpu().detach().numpy()
    depth_pred = depth_pred.cpu().detach().numpy()

    # Combine them, keeping the only batch element
    keypointxy_depth_pred = np.zeros((3, num_keypoints))
    keypointxy_depth_pred[0, :] = coord_x[0, :, 0]
    keypointxy_depth_pred[1, :] = coord_y[0, :, 0]
    keypointxy_depth_pred[2, :] = depth_pred[0, :, 0]
    return keypointxy_depth_pred
コード例 #6
0
    def test_2d_heatmap_cpu(self):
        """Check that every heatmap from heatmap_from_predict sums to 1.

        A random prediction is converted into heatmaps and the sum over each
        (batch, keypoint) map is required to be within 1e-4 of 1.0.
        """
        from mankey.network.predict import heatmap_from_predict

        n_batch, n_keypoint, image_size = 16, 6, 256
        tolerance = 1e-4

        # Random scores stand in for a network prediction
        scores = torch.rand(size=(n_batch, n_keypoint, image_size, image_size))
        heatmap = heatmap_from_predict(scores, n_keypoint)

        # Visit every (batch, keypoint) map in the same order as a nested loop
        index_pairs = [(batch_idx, keypoint_idx)
                       for batch_idx in range(n_batch)
                       for keypoint_idx in range(n_keypoint)]
        for batch_idx, keypoint_idx in index_pairs:
            total_mass = heatmap[batch_idx, keypoint_idx, :, :].sum()
            self.assertTrue(abs(float(total_mass.item()) - 1.0) < tolerance)
コード例 #7
0
def inference_resnet_nostage_lstm(
    network,  # type: resnet_nostage.ResnetNoStage
    imgproc_out,  # type: ImageProcOut
    hidden,  # type: tuple
):  # type: (resnet_nostage.ResnetNoStage, ImageProcOut, tuple) -> tuple
    """
    Run the LSTM network on one processed image and collect the keypoint prediction.

    The image is wrapped twice (batch of one, then a sequence around it) and is
    not moved to another device; the CPU coordinate routine is used.

    :param network: called as network(sequence, (h, c)); must return (raw_pred, (hn, cn))
    :param imgproc_out: its stacked_rgbd numpy array is the network input
    :param hidden: the pair (h, c) of LSTM hidden-layer states handed to the network
    :return: (keypointxy_depth_pred, (hn, cn)). The first item is a (3, n_keypoint) np array:
             (0:2, :) are x and y coords of keypoints in [-0.5, 0.5]
             (2, :) are the scaled depth
             The second item is the state pair returned by the network.
    """
    # Wrap the image into a batch of one, then into a sequence around that batch
    batch = torch.unsqueeze(torch.from_numpy(imgproc_out.stacked_rgbd), dim=0)
    sequence = batch.unsqueeze(0)

    # Forward pass with the incoming state; unpacking enforces a two-element pair
    h_prev, c_prev = hidden
    raw_pred, (hn, cn) = network(sequence, (h_prev, c_prev))

    n_channels = raw_pred.shape[1]
    num_keypoints = n_channels // 2
    # The channel count must be even: one probability and one depth map per keypoint
    assert n_channels == 2 * num_keypoints

    heatmap = predict.heatmap_from_predict(
        raw_pred[:, 0:num_keypoints, :, :], num_keypoints)
    depthmap_pred = raw_pred[:, num_keypoints:, :, :]
    coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_cpu(
        heatmap, num_keypoints)
    depth_pred = predict.depth_integration(heatmap, depthmap_pred)

    # Bring each result back to numpy and keep the only batch element
    rows = [
        tensor.cpu().detach().numpy()[0, :, 0]
        for tensor in (coord_x, coord_y, depth_pred)
    ]

    # Stack x, y and depth into one (3, n_keypoint) array
    keypointxy_depth_pred = np.zeros((3, num_keypoints))
    for row_idx, values in enumerate(rows):
        keypointxy_depth_pred[row_idx, :] = values

    return keypointxy_depth_pred, (hn, cn)
コード例 #8
0
def inference_resnet_nostage(
    network,  # type: resnet_nostage.ResnetNoStage
    imgproc_out,  # type: ImageProcOut
    raw_pred_dump_path=None,  # type: str
):  # type: (resnet_nostage.ResnetNoStage, ImageProcOut, str) -> np.ndarray
    """
    Run the network on one processed image and collect the keypoint prediction.

    The input tensor is NOT moved to CUDA here and the CPU coordinate routine is used.
    NOTE(review): the network presumably has to live on the CPU as well - confirm with callers.

    :param network: called with the (1, ...) batched image tensor
    :param imgproc_out: its stacked_rgbd numpy array is the network input
    :param raw_pred_dump_path: optional debugging aid. When given, the raw network
                               output is written there with torch.save. This used to
                               happen unconditionally to a hard-coded path inside one
                               developer's home directory, which fails on any other machine.
    :return: (3, n_keypoint) np array. (0:2, :) are x and y coords of keypoints in [-0.5, 0.5]
                                       (2, :) are the scaled depth
    """
    # Wrap the image into a batch of one; it deliberately stays on the CPU
    stacked_rgbd = torch.from_numpy(imgproc_out.stacked_rgbd)
    stacked_rgbd = torch.unsqueeze(stacked_rgbd, dim=0)

    # Do forward
    raw_pred = network(stacked_rgbd)
    if raw_pred_dump_path is not None:
        torch.save(raw_pred, raw_pred_dump_path)

    num_keypoints = raw_pred.shape[1] // 2
    # The channel count must be even: one probability and one depth map per keypoint
    assert raw_pred.shape[1] == 2 * num_keypoints
    prob_pred = raw_pred[:, 0:num_keypoints, :, :]
    depthmap_pred = raw_pred[:, num_keypoints:, :, :]
    heatmap = predict.heatmap_from_predict(prob_pred, num_keypoints)
    coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_cpu(
        heatmap, num_keypoints)
    depth_pred = predict.depth_integration(heatmap, depthmap_pred)

    # The scaled image coord and depth
    coord_x = coord_x.cpu().detach().numpy()
    coord_y = coord_y.cpu().detach().numpy()
    depth_pred = depth_pred.cpu().detach().numpy()

    # Combine them, keeping the only batch element
    keypointxy_depth_pred = np.zeros((3, num_keypoints))
    keypointxy_depth_pred[0, :] = coord_x[0, :, 0]
    keypointxy_depth_pred[1, :] = coord_y[0, :, 0]
    keypointxy_depth_pred[2, :] = depth_pred[0, :, 0]
    return keypointxy_depth_pred
コード例 #9
0
def visualize_entry_nostage(entry_idx: int, network: torch.nn.Module,
                            dataset: SupervisedKeypointDataset,
                            config: SupervisedKeypointDatasetConfig,
                            save_dir: str):
    """
    Predict the keypoints of one dataset entry, save an annotated image and print the errors.

    The prediction is mapped back to pixel coordinates with the network input
    patch size from config, and to a depth value with config.depth_image_scale
    and config.depth_image_mean. The annotated image is written to
    save_dir/image_<entry_idx>_rgb.png.

    :param entry_idx: index into dataset / dataset.entry_list
    :param network: called with the image after it has been moved to CUDA
    :param dataset: source of the processed entry and of the network input
    :param config: provides the patch size and the depth normalisation constants
    :param save_dir: existing directory that receives the image
    """
    # The raw input
    processed_entry = dataset.get_processed_entry(
        dataset.entry_list[entry_idx])

    # The processed input, uploaded as a batch of one
    stacked_rgbd = dataset[entry_idx][parameter.rgbd_image_key]
    stacked_rgbd = torch.from_numpy(stacked_rgbd)
    stacked_rgbd = torch.unsqueeze(stacked_rgbd, dim=0)
    stacked_rgbd = stacked_rgbd.cuda()

    # Do forward
    raw_pred = network(stacked_rgbd)
    prob_pred = raw_pred[:, 0:dataset.num_keypoints, :, :]
    depthmap_pred = raw_pred[:, dataset.num_keypoints:, :, :]
    heatmap = predict.heatmap_from_predict(prob_pred, dataset.num_keypoints)
    coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
        heatmap, dataset.num_keypoints)
    depth_pred = predict.depth_integration(heatmap, depthmap_pred)

    # To actual image coord: shift by 0.5 and scale by the patch size
    coord_x = coord_x.cpu().detach().numpy()
    coord_y = coord_y.cpu().detach().numpy()
    coord_x = (coord_x + 0.5) * config.network_in_patch_width
    coord_y = (coord_y + 0.5) * config.network_in_patch_height

    # To actual depth value
    depth_pred = depth_pred.cpu().detach().numpy()
    depth_pred = (depth_pred *
                  config.depth_image_scale) + config.depth_image_mean

    # Combine them. The builtin int replaces np.int, which was only an alias
    # for it and has been removed from NumPy (AttributeError since 1.24).
    keypointxy_depth_pred = np.zeros((3, dataset.num_keypoints), dtype=int)
    keypointxy_depth_pred[0, :] = coord_x[0, :, 0].astype(int)
    keypointxy_depth_pred[1, :] = coord_y[0, :, 0].astype(int)
    keypointxy_depth_pred[2, :] = depth_pred[0, :, 0].astype(int)

    # Get the image
    from mankey.utils.imgproc import draw_image_keypoint, draw_visible_heatmap
    keypoint_rgb_cv = draw_image_keypoint(processed_entry.cropped_rgb,
                                          keypointxy_depth_pred,
                                          processed_entry.keypoint_validity)
    rgb_save_path = os.path.join(save_dir, 'image_%d_rgb.png' % entry_idx)
    cv2.imwrite(rgb_save_path, keypoint_rgb_cv)

    # The depth error
    depth_error_mm = np.abs(processed_entry.keypoint_xy_depth[2, :] -
                            keypointxy_depth_pred[2, :])
    max_depth_error = np.max(depth_error_mm)
    print('Entry %d' % entry_idx)
    print('The max depth error (mm) is ', max_depth_error)

    # The pixel error.
    # NOTE(review): sqrt is applied per component before the sum, so this is
    # |dx| + |dy| (L1), not the Euclidean distance. Left unchanged so reported
    # numbers stay comparable - confirm which metric is intended.
    pixel_error = np.sum(np.sqrt((processed_entry.keypoint_xy_depth[0:2, :] -
                                  keypointxy_depth_pred[0:2, :])**2),
                         axis=0)
    max_pixel_error = np.max(pixel_error)
    print('The max pixel error (pixel in 256x256 image) is ', max_pixel_error)
コード例 #10
0
def train(checkpoint_dir: str, start_from_ckpnt="", save_epoch_offset=0):
    """
    Train the sequence (LSTM) keypoint network and report validation errors every epoch.

    Reads the module-level settings training_data_path, validation_data_path,
    learning_rate, n_epoch and segmented_img_size.

    :param checkpoint_dir: directory for the checkpoints; created when missing.
                           A checkpoint is written every 10 epochs (not at epoch 0).
    :param start_from_ckpnt: path of a state dict to start from (loaded with
                             strict=False); when empty, init_from_modelzoo is used
    :param save_epoch_offset: added to the epoch number in checkpoint file names
    """
    time_sequence_length = 5
    dataset_train, train_config = construct_dataset(
        True, training_data_path, time_sequence_length=time_sequence_length)
    dataset_val, val_config = construct_dataset(
        False, validation_data_path, time_sequence_length=time_sequence_length)

    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=4,
                              shuffle=True,
                              num_workers=4)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=4,
                            shuffle=False,
                            num_workers=1)

    network, net_config = construct_network(True)
    if start_from_ckpnt:
        # strict=False will allow the model to load the parameters from a
        # previously trained model that didn't use the LSTM extension. The
        # loaded model was still trained on keypoint detection for the same
        # object class but only learned on images without occluded keypoints.
        network.load_state_dict(torch.load(start_from_ckpnt), strict=False)
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with missing parents
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 10 == 0 and epoch > 0:
            file_name = f"checkpoint_{(epoch + save_epoch_offset):04d}.pth"
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step.
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the
        # optimizer steps of the epoch; stepping first shifts the milestones
        # one epoch earlier. Kept as is to preserve the existing schedule.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]

            # Those have the shape [batch_size, seq_length, ...]
            keypoint_xy_depth = merge_first_two_dims(
                data[parameter.keypoint_xyd_key])
            keypoint_weight = merge_first_two_dims(
                data[parameter.keypoint_validity_key])

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict
            optimizer.zero_grad()
            # don't care about hidden states
            raw_pred, _ = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            loss = sequence_symmetric_loss_l1(
                xy_depth_pred,
                keypoint_xy_depth,
                keypoint_weight,
                sequence_length=time_sequence_length)
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            xy_loss, depth_loss = symmetric_loss_l1_separate(
                xy_depth_pred, keypoint_xy_depth, keypoint_weight)

            xy_error = float(xy_loss.item())
            depth_error = float(depth_loss.item())
            if idx % 400 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                s = segmented_img_size
                print(f"The averaged pixel error is (pixel in {s}x{s} image):",
                      segmented_img_size * xy_error / len(xy_depth_pred))
                print(
                    "The averaged depth error is (mm):",
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))

            # Update info
            train_error_xy += xy_error
            train_error_depth += depth_error

        # The info at epoch level; every dataset item holds a whole sequence
        print('Epoch %d' % epoch)
        s = segmented_img_size
        sample_count = len(dataset_train) * time_sequence_length
        print(
            f"The training averaged pixel error is (pixel in {s}x{s} image):",
            s * train_error_xy / sample_count)
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth / sample_count)

        # Evaluation for epochs
        network.eval()
        val_error_xy = 0
        val_error_depth = 0

        # The validation iteration of the data; no gradients are needed here
        with torch.no_grad():
            for idx, data in enumerate(loader_val):
                # Get the data
                image = data[parameter.rgbd_image_key]
                keypoint_xy_depth = merge_first_two_dims(
                    data[parameter.keypoint_xyd_key])
                keypoint_weight = merge_first_two_dims(
                    data[parameter.keypoint_validity_key])

                # Upload to GPU
                image = image.cuda()
                keypoint_xy_depth = keypoint_xy_depth.cuda()
                keypoint_weight = keypoint_weight.cuda()

                # Predict. Don't care about hidden states because the input
                # data is already a sequence.
                raw_pred, _ = network(image)

                prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
                depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
                heatmap = predict.heatmap_from_predict(
                    prob_pred, net_config.num_keypoints)

                # Compute the coordinate
                coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap, net_config.num_keypoints)
                depth_pred = predict.depth_integration(heatmap, depthmap_pred)

                # Concatenate them
                xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred),
                                          dim=2)

                # Log info
                xy_loss, depth_loss = symmetric_loss_l1_separate(
                    xy_depth_pred, keypoint_xy_depth, keypoint_weight)

                # Update info with the losses of THIS validation batch. The
                # previous code re-added the last training batch's error.
                val_error_xy += float(xy_loss.item())
                val_error_depth += float(depth_loss.item())

        print('Validation for epoch', epoch)
        s = segmented_img_size
        sample_count = len(dataset_val) * time_sequence_length
        print(f"The averaged pixel error is (pixel in {s}x{s} image):",
              s * val_error_xy / sample_count)
        print("The averaged depth error is (mm):",
              val_config.depth_image_scale * val_error_depth / sample_count)

        print("-" * 20)
コード例 #11
0
def train(checkpoint_dir: str, start_from_ckpnt="", save_epoch_offset=0):
    """
    Train the keypoint network and report validation errors every epoch.

    Reads the module-level settings training_data_path, validation_data_path,
    learning_rate, n_epoch and segmented_img_size.

    :param checkpoint_dir: directory for the checkpoints; created when missing.
                           A checkpoint is written every 5 epochs (not at epoch 0).
    :param start_from_ckpnt: path of a state dict to start from; when empty,
                             init_from_modelzoo is used
    :param save_epoch_offset: added to the epoch number in checkpoint file names
    """
    dataset_train, train_config = construct_dataset(True, training_data_path)
    dataset_val, val_config = construct_dataset(False, validation_data_path)

    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=16,
                              shuffle=True,
                              num_workers=16)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=16,
                            shuffle=False,
                            num_workers=4)

    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with missing parents
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 5 == 0 and epoch > 0:
            file_name = f"checkpoint_{(epoch + save_epoch_offset):04d}.pth"
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step.
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after the
        # optimizer steps of the epoch; stepping first shifts the milestones
        # one epoch earlier. Kept as is to preserve the existing schedule.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            loss = symmetric_loss_l1(xy_depth_pred, keypoint_xy_depth,
                                     keypoint_weight)
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            xy_loss, depth_loss = symmetric_loss_l1_separate(
                xy_depth_pred, keypoint_xy_depth, keypoint_weight)

            xy_error = float(xy_loss.item())
            depth_error = float(depth_loss.item())
            if idx % 400 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                s = segmented_img_size
                print(f"The averaged pixel error is (pixel in {s}x{s} image):",
                      segmented_img_size * xy_error / len(xy_depth_pred))
                print(
                    "The averaged depth error is (mm):",
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))

            # Update info
            train_error_xy += xy_error
            train_error_depth += depth_error

        # The info at epoch level
        print('Epoch %d' % epoch)
        s = segmented_img_size
        print(
            f"The training averaged pixel error is (pixel in {s}x{s} image):",
            s * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))

        # Evaluation for epochs
        network.eval()
        val_error_xy = 0
        val_error_depth = 0

        # The validation iteration of the data; no gradients are needed here
        with torch.no_grad():
            for idx, data in enumerate(loader_val):
                # Get the data
                image = data[parameter.rgbd_image_key]
                keypoint_xy_depth = data[parameter.keypoint_xyd_key]
                keypoint_weight = data[parameter.keypoint_validity_key]

                # Upload to GPU
                image = image.cuda()
                keypoint_xy_depth = keypoint_xy_depth.cuda()
                keypoint_weight = keypoint_weight.cuda()

                # To predict
                raw_pred = network(image)

                prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
                depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
                heatmap = predict.heatmap_from_predict(
                    prob_pred, net_config.num_keypoints)

                # Compute the coordinate
                coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap, net_config.num_keypoints)
                depth_pred = predict.depth_integration(heatmap, depthmap_pred)

                # Concatenate them
                xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred),
                                          dim=2)

                # Log info
                xy_loss, depth_loss = symmetric_loss_l1_separate(
                    xy_depth_pred, keypoint_xy_depth, keypoint_weight)

                # Update info with the losses of THIS validation batch. The
                # previous code re-added the last training batch's error.
                val_error_xy += float(xy_loss.item())
                val_error_depth += float(depth_loss.item())

        print('Validation for epoch', epoch)
        s = segmented_img_size
        print(f"The averaged pixel error is (pixel in {s}x{s} image):",
              s * val_error_xy / len(dataset_val))
        print(
            "The averaged depth error is (mm):",
            val_config.depth_image_scale * val_error_depth / len(dataset_val))

        print("-" * 20)
コード例 #12
0
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train the keypoint network together with its gripper-motion outputs.

    Every iteration back-propagates ``weighted_l1_loss`` on the predicted
    (x, y, depth) keypoints plus a motion loss: ten times the RMSE of the
    rotation delta, ten times the RMSE of the translation delta and the mean
    BCE of the step size.  Per-epoch error sums divided by the dataset size
    are printed and written to the module-level ``writer``, which is closed
    once training finishes.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-<epoch>.pth`` files
            every 20 epochs; created together with any missing parents.
        start_from_ckpnt: Path to a saved state dict to resume from.  When
            empty, the network is initialized with ``init_from_modelzoo``.
        save_epoch_offset: Offset added to the epoch index in checkpoint
            file names.
    """
    # Construct the dataset and its loader
    dataset_train, train_config = construct_dataset(is_train=True)
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if start_from_ckpnt:
        # Deserialize onto the CPU; load_state_dict copies the values into the
        # network's own parameters, and this keeps GPU-saved checkpoints
        # loadable when training runs with device == 'cpu'.
        network.load_state_dict(
            torch.load(start_from_ckpnt, map_location='cpu'))
    else:
        init_from_modelzoo(network, net_config)
    network.to(device)

    # The checkpoint directory: makedirs also creates missing parents and
    # tolerates the directory already existing.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Losses for the motion outputs; BCE is kept element-wise and averaged
    # explicitly below.
    criterion_rmse = RMSELoss()
    criterion_bce = torch.nn.BCELoss(reduction='none')

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_move = 0
        train_error_rot = 0
        train_error_xyz = 0
        train_error_step = 0

        # The learning rate step.
        # NOTE(review): PyTorch >= 1.1 documents that calling scheduler.step()
        # before optimizer.step() skips the first value of the schedule.  The
        # call is left here so the schedule is unchanged -- confirm against
        # the PyTorch version in use.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to the training device
            image = data[parameter.rgbd_image_key].to(device)
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].to(device)
            keypoint_weight = data[parameter.keypoint_validity_key].to(device)
            delta_rot = data[parameter.delta_rot_key].to(device)
            delta_xyz = data[parameter.delta_xyz_key].to(device)
            gripper_pose = data[parameter.gripper_pose_key].to(device)
            step_size = data[parameter.step_size_key].to(device)

            # To predict
            optimizer.zero_grad()

            # raw_pred (batch_size, num_keypoint*2, network_out_map_height, network_out_map_width)
            # prob_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            # depthmap_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            raw_pred, delta_rot_pred, delta_xyz_pred, step_size_pred = network(
                image,
                gripper_pose,
                device,
                enableGripperPose=enableGripperPose)

            # Motion loss: rotation and translation are weighted 10x against
            # the step-size term.
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_t = criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_s = criterion_bce(step_size_pred, step_size).mean()
            loss_move = loss_r * 10 + loss_t * 10 + loss_s

            # First num_keypoints channels are probabilities, the rest depth
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)

            # Compute the coordinate
            if device == 'cpu':
                coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_cpu(
                    heatmap, net_config.num_keypoints)
            else:
                coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            loss_kpt = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth,
                                        keypoint_weight)
            loss = loss_kpt + loss_move
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                        keypoint_xy_depth[:, :, 0:2],
                                        keypoint_weight[:, :, 0:2]).item()
            depth_error = weighted_l1_loss(xy_depth_pred[:, :, 2],
                                           keypoint_xy_depth[:, :, 2],
                                           keypoint_weight[:, :, 2]).item()

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_move += float(loss_move)
            train_error_rot += float(loss_r)
            train_error_xyz += float(loss_t)
            train_error_step += float(loss_s)

            # cleanup
            del loss

        # The info at epoch level
        n_sample = len(dataset_train)
        avg_pixel_error = 256 * train_error_xy / n_sample
        avg_depth_error = (train_config.depth_image_scale *
                           train_error_depth / n_sample)
        avg_move_error = train_error_move / n_sample
        avg_rot_error = train_error_rot / n_sample
        avg_xyz_error = train_error_xyz / n_sample
        avg_step_error = train_error_step / n_sample

        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            avg_pixel_error)
        print('The training averaged depth error is (mm): ', avg_depth_error)
        print('The training averaged rot error is: ', avg_rot_error)
        print('The training averaged xyz error is: ', avg_xyz_error)
        print('The training averaged step error is: ', avg_step_error)

        writer.add_scalar('average pixel error', avg_pixel_error, epoch)
        writer.add_scalar('average depth error', avg_depth_error, epoch)
        writer.add_scalar('average move error', avg_move_error, epoch)
        writer.add_scalar('average rot error', avg_rot_error, epoch)
        writer.add_scalar('average xyz error', avg_xyz_error, epoch)
        writer.add_scalar('average step error', avg_step_error, epoch)
    writer.close()
コード例 #13
0
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train the network by regressing its heatmap onto a target heatmap.

    Only the MSE between the predicted heatmap and the target heatmap is
    back-propagated.  A ``weighted_l1_loss`` on the integrated depth is
    computed as well, but it is used for logging only.  Per-epoch error sums
    divided by the dataset size are printed and written to the module-level
    ``writer``, which is closed once training finishes.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-<epoch>.pth`` files
            every 20 epochs; created together with any missing parents.
        start_from_ckpnt: Path to a saved state dict to resume from.  When
            empty, the network is initialized with ``init_from_modelzoo``.
        save_epoch_offset: Offset added to the epoch index in checkpoint
            file names.
    """
    # Construct the dataset and its loader
    dataset_train, train_config = construct_dataset(is_train=True)
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory: makedirs also creates missing parents and
    # tolerates the directory already existing.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step.
        # NOTE(review): PyTorch >= 1.1 documents that calling scheduler.step()
        # before optimizer.step() skips the first value of the schedule.  The
        # call is left here so the schedule is unchanged -- confirm against
        # the PyTorch version in use.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to the GPU
            image = data[parameter.rgbd_image_key].cuda()
            # keypoint_xy_depth (batch_size, num_keypoint, xydepth)
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()
            target_heatmap = data[parameter.target_heatmap_key].cuda()

            # To predict
            optimizer.zero_grad()

            # raw_pred (batch_size, num_keypoint*2, network_out_map_height, network_out_map_width)
            # prob_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            # depthmap_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)

            # Integrate the depth and drop the trailing singleton axis
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)
            depth_pred = depth_pred[:, :, 0]

            # Compute loss
            depth_loss = weighted_l1_loss(depth_pred,
                                          keypoint_xy_depth[:, :, 2],
                                          keypoint_weight[:, :, 2])
            heatmap_loss = heatmap_criterion(heatmap, target_heatmap)

            # NOTE(review): these dumps run on every iteration, which costs a
            # device-to-host copy and a disk write per step.  They look like a
            # debugging leftover -- confirm nothing reads the files before
            # throttling or removing them.
            np.save('pred_heatmap.npy', heatmap.cpu().detach().numpy())
            np.save('target_heatmap.npy',
                    target_heatmap.cpu().detach().numpy())
            if idx % 100 == 0:
                print('depth loss:', depth_loss)
                print('heatmap loss:', heatmap_loss)

            # Only the heatmap loss drives the update
            heatmap_loss.backward()
            optimizer.step()

            # Log info
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                heatmap)
            xy_error = weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                        keypoint_xy_depth[:, :, 0:2],
                                        keypoint_weight[:, :, 0:2]).item()
            # depth_loss was computed from the same tensors before the update,
            # so it already holds the value to log; no need to recompute it.
            depth_error = depth_loss.item()
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)

        # The info at epoch level
        n_sample = len(dataset_train)
        avg_pixel_error = 256 * train_error_xy / n_sample
        avg_depth_error = (train_config.depth_image_scale *
                           train_error_depth / n_sample)

        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            avg_pixel_error)
        print('The training averaged depth error is (mm): ', avg_depth_error)
        writer.add_scalar('average pixel error', avg_pixel_error, epoch)
        writer.add_scalar('average depth error', avg_depth_error, epoch)
    writer.close()
コード例 #14
0
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train a multi-stage network wrapped in ``torch.nn.DataParallel``.

    The network returns one prediction per stage.  The last stage is split
    into probability and depth channels and supervised with
    ``weighted_l1_loss`` on (x, y, depth); every earlier stage is supervised
    on (x, y) only.  All stage losses are summed before the update.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-<epoch>.pth`` files
            every 2 epochs; created together with any missing parents.
        start_from_ckpnt: Path to a saved state dict to resume from.  It is
            loaded into the DataParallel-wrapped network; when empty, no
            weights are loaded.
        save_epoch_offset: Offset added to the epoch index in checkpoint
            file names.
    """
    # Construct the dataset and its loader
    dataset_train, train_config = construct_dataset(is_train=True)
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=8,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()

    # To cuda
    network = torch.nn.DataParallel(network).cuda()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))

    # The checkpoint directory: makedirs also creates missing parents and
    # tolerates the directory already existing.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 2 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_internal = 0

        # The learning rate step.
        # NOTE(review): PyTorch >= 1.1 documents that calling scheduler.step()
        # before optimizer.step() skips the first value of the schedule.  The
        # call is left here so the schedule is unchanged -- confirm against
        # the PyTorch version in use.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and move it to the GPU
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)

            # The last stage carries both probability and depth channels
            raw_pred_last = raw_pred[-1]
            prob_pred_last = raw_pred_last[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred_last = raw_pred_last[:,
                                               net_config.num_keypoints:, :, :]
            heatmap_last = predict.heatmap_from_predict(
                prob_pred_last, net_config.num_keypoints)

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                heatmap_last, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap_last,
                                                   depthmap_pred_last)

            # Concatenate them
            xy_depth_pred = torch.cat([coord_x, coord_y, depth_pred], dim=2)

            # Compute loss
            loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth,
                                    keypoint_weight)

            # For all other stages: only 2d prediction on previous layers
            for prob_pred_i in raw_pred[:-1]:
                assert prob_pred_i.shape == prob_pred_last.shape
                heatmap_i = predict.heatmap_from_predict(
                    prob_pred_i, net_config.num_keypoints)
                coord_x_i, coord_y_i = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap_i, net_config.num_keypoints)
                xy_pred_i = torch.cat([coord_x_i, coord_y_i], dim=2)
                loss = loss + weighted_l1_loss(xy_pred_i,
                                               keypoint_xy_depth[:, :, 0:2],
                                               keypoint_weight[:, :, 0:2])

            # The SGD step
            loss.backward()
            optimizer.step()
            del loss

            # Log info
            xy_error = weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                        keypoint_xy_depth[:, :, 0:2],
                                        keypoint_weight[:, :, 0:2]).item()
            depth_error = weighted_l1_loss(xy_depth_pred[:, :, 2],
                                           keypoint_xy_depth[:, :, 2],
                                           keypoint_weight[:, :, 2]).item()
            # The error of internal stage
            keypoint_xy_pred_internal, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                raw_pred[0])
            xy_error_internal = weighted_l1_loss(
                keypoint_xy_pred_internal[:, :, 0:2],
                keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item()
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                # Depth is scaled by the dataset's depth scale, the same
                # factor the epoch-level report below uses, not by the image
                # size of 256.
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))
                print(
                    'The averaged internal pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_internal / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_internal += float(xy_error_internal)

        # The info at epoch level
        n_sample = len(dataset_train)
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / n_sample)
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth / n_sample)
        print(
            'The training internal averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_internal / n_sample)
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train a network with an integral keypoint head and a regressed heatmap.

    The network output is split channel-wise into three groups of
    ``net_config.num_keypoints``-sized slices: probability maps, depth maps
    and a directly regressed heatmap.  The loss is ``weighted_mse_loss`` on
    the integrated (x, y, depth) plus ``heatmap_loss_weight`` times the MSE
    between the regressed heatmap and the target heatmap.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-<epoch>.pth`` files
            every 4 epochs; created together with any missing parents.
        start_from_ckpnt: Path to a saved state dict to resume from.  When
            empty, the network is initialized with ``init_from_modelzoo``.
        save_epoch_offset: Offset added to the epoch index in checkpoint
            file names.
    """
    # Construct the dataset and its loader
    dataset_train, train_config = construct_dataset(is_train=True)
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory: makedirs also creates missing parents and
    # tolerates the directory already existing.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 4 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_heatmap = 0

        # The learning rate step.
        # NOTE(review): PyTorch >= 1.1 documents that calling scheduler.step()
        # before optimizer.step() skips the first value of the schedule.  The
        # call is left here so the schedule is unchanged -- confirm against
        # the PyTorch version in use.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to cuda
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()
            target_heatmap = data[parameter.target_heatmap_key].cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)

            # Channel layout: [probability | depth | regressed heatmap]
            n_keypoint = net_config.num_keypoints
            prob_pred = raw_pred[:, 0:n_keypoint, :, :]
            depthmap_pred = raw_pred[:, n_keypoint:2 * n_keypoint, :, :]
            regress_heatmap = raw_pred[:, 2 * n_keypoint:, :, :]
            integral_heatmap = predict.heatmap_from_predict(
                prob_pred, net_config.num_keypoints)

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                integral_heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(integral_heatmap,
                                                   depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            loss = weighted_mse_loss(xy_depth_pred, keypoint_xy_depth,
                                     keypoint_weight)
            loss = loss + heatmap_loss_weight * heatmap_criterion(
                regress_heatmap, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # Log info (L1 errors, independent of the MSE training loss)
            xy_error = weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                        keypoint_xy_depth[:, :, 0:2],
                                        keypoint_weight[:, :, 0:2]).item()
            depth_error = weighted_l1_loss(xy_depth_pred[:, :, 2],
                                           keypoint_xy_depth[:, :, 2],
                                           keypoint_weight[:, :, 2]).item()
            keypoint_xy_pred_heatmap, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                regress_heatmap)
            xy_error_heatmap = weighted_l1_loss(
                keypoint_xy_pred_heatmap[:, :, 0:2],
                keypoint_xy_depth[:, :, 0:2],
                keypoint_weight[:, :, 0:2]).item()
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))
                print(
                    'The averaged heatmap argmax pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_heatmap / len(xy_depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_heatmap += float(xy_error_heatmap)

        # The info at epoch level
        n_sample = len(dataset_train)
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / n_sample)
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth / n_sample)
        print(
            'The training averaged heatmap pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_heatmap / n_sample)