def test_output_size(self):
    """Check the output tensor shape of a model-zoo initialised ResnetNoStage.

    A batch of all-zero 256x256 images is pushed through the network and each
    of the four output dimensions is compared with its expected size: batch,
    keypoints * depth-per-keypoint, and the input side divided by 4 (twice).
    """
    from mankey.network.resnet_nostage import ResnetNoStageConfig, ResnetNoStage, init_from_modelzoo

    # Network description used by this test
    cfg = ResnetNoStageConfig()
    cfg.num_layers = 50
    cfg.num_keypoints = 10
    cfg.depth_per_keypoint = 1
    cfg.image_channels = 4
    model = ResnetNoStage(cfg)

    # Initialise the weights from the model zoo
    init_from_modelzoo(model, cfg)

    # Run on a dummy image batch
    n_images = 10
    side = 256
    dummy = torch.zeros((n_images, cfg.image_channels, side, side))
    result = model(dummy)

    # Compare every axis of the output with its expected extent
    expected_shape = (
        n_images,
        cfg.num_keypoints * cfg.depth_per_keypoint,
        side / 4,
        side / 4,
    )
    for axis, expected in enumerate(expected_shape):
        self.assertEqual(result.shape[axis], expected)
Ejemplo n.º 2
0
def train(checkpoint_dir: str, start_from_ckpnt="", save_epoch_offset=0):
    """Train the sequence keypoint network and validate it after every epoch.

    Datasets, network, Adam optimizer and a MultiStepLR schedule (milestones
    60 and 90, gamma 0.1) are built from the module-level settings
    ``training_data_path``, ``validation_data_path``, ``learning_rate``,
    ``n_epoch`` and ``segmented_img_size``. A checkpoint is written every 10
    epochs (never at epoch 0). Progress is reported with ``print``.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint_XXXX.pth`` files;
            created (including parents) when missing.
        start_from_ckpnt: Path of a state dict to start from. When empty, the
            network is initialised with ``init_from_modelzoo`` instead.
        save_epoch_offset: Added to the epoch number in checkpoint file names.
    """
    time_sequence_length = 5
    dataset_train, train_config = construct_dataset(
        True, training_data_path, time_sequence_length=time_sequence_length)
    dataset_val, val_config = construct_dataset(
        False, validation_data_path, time_sequence_length=time_sequence_length)

    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=4,
                              shuffle=True,
                              num_workers=4)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=4,
                            shuffle=False,
                            num_workers=1)

    network, net_config = construct_network(True)
    if start_from_ckpnt:
        # strict=False will allow the model to load the parameters from a
        # previously trained model that didn't use the LSTM extension. The
        # loaded model was still trained on keypoint detection for the same
        # object class but only learned on images without occluded keypoints.
        network.load_state_dict(torch.load(start_from_ckpnt), strict=False)
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with nested paths
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    def predict_xy_depth(image):
        """Run the network and return the concatenated (x, y, depth) prediction."""
        # The network also returns hidden states, which are not needed here
        # because the input is already a sequence.
        raw_pred, _ = network(image)
        prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
        depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
        heatmap = predict.heatmap_from_predict(prob_pred,
                                               net_config.num_keypoints)
        coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
            heatmap, net_config.num_keypoints)
        depth_pred = predict.depth_integration(heatmap, depthmap_pred)
        return torch.cat((coord_x, coord_y, depth_pred), dim=2)

    s = segmented_img_size

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 10 == 0 and epoch > 0:
            file_name = f"checkpoint_{(epoch + save_epoch_offset):04d}.pth"
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step
        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # steps matches the pre-1.1 PyTorch convention; on newer versions the
        # documentation asks for scheduler.step() after optimizer.step().
        # Confirm against the PyTorch version in use before moving it.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to the GPU.
            # The targets have the shape [batch_size, seq_length, ...]
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = merge_first_two_dims(
                data[parameter.keypoint_xyd_key]).cuda()
            keypoint_weight = merge_first_two_dims(
                data[parameter.keypoint_validity_key]).cuda()

            # To predict
            optimizer.zero_grad()
            xy_depth_pred = predict_xy_depth(image)

            # Compute loss
            loss = sequence_symmetric_loss_l1(
                xy_depth_pred,
                keypoint_xy_depth,
                keypoint_weight,
                sequence_length=time_sequence_length)
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            xy_loss, depth_loss = symmetric_loss_l1_separate(
                xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            xy_error = float(xy_loss.item())
            depth_error = float(depth_loss.item())
            if idx % 400 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print(f"The averaged pixel error is (pixel in {s}x{s} image):",
                      s * xy_error / len(xy_depth_pred))
                print(
                    "The averaged depth error is (mm):",
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))

            # Update info
            train_error_xy += xy_error
            train_error_depth += depth_error

        # The info at epoch level
        print('Epoch %d' % epoch)
        sample_count = len(dataset_train) * time_sequence_length
        print(
            f"The training averaged pixel error is (pixel in {s}x{s} image):",
            s * train_error_xy / sample_count)
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth / sample_count)

        # Evaluation for epochs
        network.eval()
        val_error_xy = 0
        val_error_depth = 0

        # The validation iteration of the data; no gradients are needed here
        with torch.no_grad():
            for data in loader_val:
                image = data[parameter.rgbd_image_key].cuda()
                keypoint_xy_depth = merge_first_two_dims(
                    data[parameter.keypoint_xyd_key]).cuda()
                keypoint_weight = merge_first_two_dims(
                    data[parameter.keypoint_validity_key]).cuda()

                xy_depth_pred = predict_xy_depth(image)

                xy_loss, depth_loss = symmetric_loss_l1_separate(
                    xy_depth_pred, keypoint_xy_depth, keypoint_weight)

                # Accumulate the errors of THIS validation batch (previously
                # the stale values of the last training batch were added).
                val_error_xy += float(xy_loss.item())
                val_error_depth += float(depth_loss.item())

        print('Validation for epoch', epoch)
        sample_count = len(dataset_val) * time_sequence_length
        print(f"The averaged pixel error is (pixel in {s}x{s} image):",
              s * val_error_xy / sample_count)
        print("The averaged depth error is (mm):",
              val_config.depth_image_scale * val_error_depth / sample_count)

        print("-" * 20)
Ejemplo n.º 3
0
            training_data_path,
            time_sequence_length=time_sequence_length)
        dataset_val, val_config = construct_dataset(
            False,
            validation_data_path,
            time_sequence_length=time_sequence_length)

        loader_train = DataLoader(dataset=dataset_train,
                                  batch_size=4,
                                  shuffle=True,
                                  num_workers=4)
        loader_val = DataLoader(dataset=dataset_val,
                                batch_size=4,
                                shuffle=False,
                                num_workers=1)

        # Construct the regressor
        network, net_config = construct_network(True)
        start_from_ckpnt = False
        if start_from_ckpnt:
            network.load_state_dict(torch.load(start_from_ckpnt))
        else:
            init_from_modelzoo(network, net_config)
        # network.cuda()

        for data in loader_val:
            d = data
            break

        image = d["rgbd_image"]
Ejemplo n.º 4
0
def train(checkpoint_dir: str, start_from_ckpnt="", save_epoch_offset=0):
    """Train the keypoint network and validate it after every epoch.

    Datasets, network, Adam optimizer and a MultiStepLR schedule (milestones
    60 and 90, gamma 0.1) are built from the module-level settings
    ``training_data_path``, ``validation_data_path``, ``learning_rate``,
    ``n_epoch`` and ``segmented_img_size``. A checkpoint is written every 5
    epochs (never at epoch 0). Progress is reported with ``print``.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint_XXXX.pth`` files;
            created (including parents) when missing.
        start_from_ckpnt: Path of a state dict to start from. When empty, the
            network is initialised with ``init_from_modelzoo`` instead.
        save_epoch_offset: Added to the epoch number in checkpoint file names.
    """
    dataset_train, train_config = construct_dataset(True, training_data_path)
    dataset_val, val_config = construct_dataset(False, validation_data_path)

    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=16,
                              shuffle=True,
                              num_workers=16)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=16,
                            shuffle=False,
                            num_workers=4)

    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with nested paths
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    def predict_xy_depth(image):
        """Run the network and return the concatenated (x, y, depth) prediction."""
        raw_pred = network(image)
        prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
        depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
        heatmap = predict.heatmap_from_predict(prob_pred,
                                               net_config.num_keypoints)
        coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
            heatmap, net_config.num_keypoints)
        depth_pred = predict.depth_integration(heatmap, depthmap_pred)
        return torch.cat((coord_x, coord_y, depth_pred), dim=2)

    s = segmented_img_size

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 5 == 0 and epoch > 0:
            file_name = f"checkpoint_{(epoch + save_epoch_offset):04d}.pth"
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step
        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # steps matches the pre-1.1 PyTorch convention; on newer versions the
        # documentation asks for scheduler.step() after optimizer.step().
        # Confirm against the PyTorch version in use before moving it.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to the GPU
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()

            # To predict
            optimizer.zero_grad()
            xy_depth_pred = predict_xy_depth(image)

            # Compute loss
            loss = symmetric_loss_l1(xy_depth_pred, keypoint_xy_depth,
                                     keypoint_weight)
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            xy_loss, depth_loss = symmetric_loss_l1_separate(
                xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            xy_error = float(xy_loss.item())
            depth_error = float(depth_loss.item())
            if idx % 400 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print(f"The averaged pixel error is (pixel in {s}x{s} image):",
                      s * xy_error / len(xy_depth_pred))
                print(
                    "The averaged depth error is (mm):",
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))

            # Update info
            train_error_xy += xy_error
            train_error_depth += depth_error

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            f"The training averaged pixel error is (pixel in {s}x{s} image):",
            s * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))

        # Evaluation for epochs
        network.eval()
        val_error_xy = 0
        val_error_depth = 0

        # The validation iteration of the data; no gradients are needed here
        with torch.no_grad():
            for data in loader_val:
                image = data[parameter.rgbd_image_key].cuda()
                keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
                keypoint_weight = data[parameter.keypoint_validity_key].cuda()

                xy_depth_pred = predict_xy_depth(image)

                xy_loss, depth_loss = symmetric_loss_l1_separate(
                    xy_depth_pred, keypoint_xy_depth, keypoint_weight)

                # Accumulate the errors of THIS validation batch (previously
                # the stale values of the last training batch were added).
                val_error_xy += float(xy_loss.item())
                val_error_depth += float(depth_loss.item())

        print('Validation for epoch', epoch)
        print(f"The averaged pixel error is (pixel in {s}x{s} image):",
              s * val_error_xy / len(dataset_val))
        print(
            "The averaged depth error is (mm):",
            val_config.depth_image_scale * val_error_depth / len(dataset_val))

        print("-" * 20)
Ejemplo n.º 5
0
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train the network with an MSE loss between predicted and target heatmaps.

    Only the heatmap loss is back-propagated; the depth loss is computed for
    logging. Per-epoch averages are printed and sent to the module-level
    ``writer``, which is closed when training finishes. ``learning_rate`` and
    ``n_epoch`` are read from module level. A checkpoint is written every 20
    epochs (never at epoch 0). No validation pass is run.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-N.pth`` files;
            created (including parents) when missing.
        start_from_ckpnt: Path of a state dict to start from. When empty, the
            network is initialised with ``init_from_modelzoo`` instead.
        save_epoch_offset: Added to the epoch number in checkpoint file names.
    """
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with nested paths
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step
        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # steps matches the pre-1.1 PyTorch convention; on newer versions the
        # documentation asks for scheduler.step() after optimizer.step().
        # Confirm against the PyTorch version in use before moving it.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to the GPU.
            # Per the original author's note, keypoint_xy_depth is laid out
            # as (batch_size, num_keypoint, xydepth).
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()
            target_heatmap = data[parameter.target_heatmap_key].cuda()

            # To predict
            optimizer.zero_grad()

            # The first num_keypoints channels are the heatmap logits and the
            # remaining channels are the per-keypoint depth maps.
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)
            depth_pred = depth_pred[:, :, 0]

            # Compute loss
            depth_loss = weighted_l1_loss(depth_pred,
                                          keypoint_xy_depth[:, :, 2],
                                          keypoint_weight[:, :, 2])
            heatmap_loss = heatmap_criterion(heatmap, target_heatmap)

            # NOTE(review): these dumps overwrite the same two files on every
            # iteration and force a GPU->CPU copy; they look like debugging
            # leftovers. Kept because they are an observable side effect.
            np.save('pred_heatmap.npy', heatmap.cpu().detach().numpy())
            np.save('target_heatmap.npy',
                    target_heatmap.cpu().detach().numpy())
            if idx % 100 == 0:
                print('depth loss:', depth_loss)
                print('heatmap loss:', heatmap_loss)

            # Only the heatmap loss drives the update
            loss = heatmap_loss
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                heatmap)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            # Reuse the depth loss computed above instead of evaluating the
            # same expression a second time.
            depth_error = float(depth_loss.item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(depth_pred))

            # Update info
            train_error_xy += xy_error
            train_error_depth += depth_error

        # The info at epoch level
        avg_pixel_error = 256 * train_error_xy / len(dataset_train)
        avg_depth_error = (train_config.depth_image_scale *
                           train_error_depth / len(dataset_train))
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            avg_pixel_error)
        print('The training averaged depth error is (mm): ', avg_depth_error)
        writer.add_scalar('average pixel error', avg_pixel_error, epoch)
        writer.add_scalar('average depth error', avg_depth_error, epoch)
    writer.close()
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train the network on target heatmaps (MSE) and validate every epoch.

    The first ``net_config.num_keypoints`` output channels are regressed
    directly against the target heatmap. Keypoint pixel error is obtained
    from the heatmap argmax and printed per epoch for both splits.
    ``learning_rate`` and ``n_epoch`` are read from module level. A
    checkpoint is written every 100 epochs (never at epoch 0).

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-N.pth`` files;
            created (including parents) when missing.
        start_from_ckpnt: Path of a state dict to start from. When empty, the
            network is initialised with ``init_from_modelzoo`` instead.
        save_epoch_offset: Added to the epoch number in checkpoint file names.
    """
    # Construct the dataset (the returned configs are not needed here)
    dataset_train, _ = construct_dataset(is_train=True)
    dataset_val, _ = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=16,
                              shuffle=True,
                              num_workers=4)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=16,
                            shuffle=False,
                            num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with nested paths
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    def pixel_error(prob_pred, keypoint_xy_depth, keypoint_weight):
        """Return the weighted L1 xy error of the heatmap-argmax keypoints."""
        keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
            prob_pred)
        return float(
            weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                             keypoint_xy_depth[:, :, 0:2],
                             keypoint_weight[:, :, 0:2]).item())

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 100 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0

        # The learning rate step
        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # steps matches the pre-1.1 PyTorch convention; on newer versions the
        # documentation asks for scheduler.step() after optimizer.step().
        # Confirm against the PyTorch version in use before moving it.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to the GPU
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()
            target_heatmap = data[parameter.target_heatmap_key].cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]

            # Compute loss
            loss = heatmap_criterion(prob_pred, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Do some pred and log
            xy_error = pixel_error(prob_pred, keypoint_xy_depth,
                                   keypoint_weight)
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])

            # Update info
            train_error_xy += xy_error

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))

        # Prepare info at epoch level
        network.eval()
        val_error_xy = 0

        # The validation iteration of the data; no gradients are needed here,
        # so autograd bookkeeping is switched off.
        with torch.no_grad():
            for data in loader_val:
                image = data[parameter.rgbd_image_key].cuda()
                keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
                keypoint_weight = data[parameter.keypoint_validity_key].cuda()

                # To predict
                pred = network(image)
                prob_pred = pred[:, 0:net_config.num_keypoints, :, :]

                # Update info
                val_error_xy += pixel_error(prob_pred, keypoint_xy_depth,
                                            keypoint_weight)

        # The info at epoch level
        print(
            'The validation averaged pixel error is (pixel in 256x256 image): ',
            256 * val_error_xy / len(dataset_val))
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    """Train the network with a keypoint MSE loss plus a weighted heatmap loss.

    The network output is split into three channel groups of
    ``net_config.num_keypoints`` each: heatmap logits, depth maps, and a
    directly regressed heatmap. The total loss is the weighted MSE of the
    integrated (x, y, depth) prediction plus ``heatmap_loss_weight`` times
    the MSE between the regressed and target heatmaps. ``learning_rate``,
    ``n_epoch`` and ``heatmap_loss_weight`` are read from module level. A
    checkpoint is written every 4 epochs (never at epoch 0). No validation
    pass is run.

    Args:
        checkpoint_dir: Directory receiving ``checkpoint-N.pth`` files;
            created (including parents) when missing.
        start_from_ckpnt: Path of a state dict to start from. When empty, the
            network is initialised with ``init_from_modelzoo`` instead.
        save_epoch_offset: Added to the epoch number in checkpoint file names.
    """
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if start_from_ckpnt:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory; makedirs also copes with nested paths
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 4 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_heatmap = 0

        # The learning rate step
        # NOTE(review): stepping the scheduler before the epoch's optimizer
        # steps matches the pre-1.1 PyTorch convention; on newer versions the
        # documentation asks for scheduler.step() after optimizer.step().
        # Confirm against the PyTorch version in use before moving it.
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data and upload it to cuda
            image = data[parameter.rgbd_image_key].cuda()
            keypoint_xy_depth = data[parameter.keypoint_xyd_key].cuda()
            keypoint_weight = data[parameter.keypoint_validity_key].cuda()
            target_heatmap = data[parameter.target_heatmap_key].cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            num_keypoints = net_config.num_keypoints
            prob_pred = raw_pred[:, 0:num_keypoints, :, :]
            depthmap_pred = raw_pred[:, num_keypoints:2 * num_keypoints, :, :]
            regress_heatmap = raw_pred[:, 2 * num_keypoints:, :, :]
            integral_heatmap = predict.heatmap_from_predict(
                prob_pred, num_keypoints)

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                integral_heatmap, num_keypoints)
            depth_pred = predict.depth_integration(integral_heatmap,
                                                   depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss: keypoint regression plus weighted heatmap term
            loss = weighted_mse_loss(xy_depth_pred, keypoint_xy_depth,
                                     keypoint_weight)
            loss = loss + heatmap_loss_weight * heatmap_criterion(
                regress_heatmap, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 2],
                                 keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            keypoint_xy_pred_heatmap, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                regress_heatmap)
            xy_error_heatmap = float(
                weighted_l1_loss(keypoint_xy_pred_heatmap[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))
                print(
                    'The averaged heatmap argmax pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_heatmap / len(xy_depth_pred))

            # Update info
            train_error_xy += xy_error
            train_error_depth += depth_error
            train_error_xy_heatmap += xy_error_heatmap

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        print(
            'The training averaged heatmap pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_heatmap / len(dataset_train))