def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=1,
                              shuffle=True,
                              num_workers=1)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    #control_network = ControlNetwork(in_channel=int(net_config.num_keypoints * net_config.depth_per_keypoint * 256/4 * 256/4))
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.to(device)
    #control_network.to(device)

    # The checkpoint
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The losses: RMSE for the rotation delta, cosine distance + RMSE for the
    # translation delta, and BCE for the step size
    criterion_rmse = RMSELoss()
    criterion_cos = torch.nn.CosineSimilarity(dim=1)
    criterion_bce = torch.nn.BCELoss(reduction='none')
    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(1, n_epoch + 1):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_move = 0
        train_error_rot = 0
        train_error_xyz = 0
        train_error_step = 0
        # The learning rate step (the pre-1.1 PyTorch convention of stepping
        # the scheduler at the start of each epoch)
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key][0]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key][0]
            keypoint_weight = data[parameter.keypoint_validity_key][0]
            delta_rot = data[parameter.delta_rot_key][0]
            delta_xyz = data[parameter.delta_xyz_key][0]
            gripper_pose = data[parameter.gripper_pose_key][0]
            step_size = data[parameter.step_size_key][0]

            # Upload to GPU
            image = image.to(device)
            keypoint_xy_depth = keypoint_xy_depth.to(device)
            keypoint_weight = keypoint_weight.to(device)
            delta_rot = delta_rot.to(device)
            delta_xyz = delta_xyz.to(device)
            gripper_pose = gripper_pose.to(device)
            step_size = step_size.to(device)
            #print('delta_rot',delta_rot.shape)
            #print('delta_xyz',delta_xyz.shape)
            #print('gripper_pose',gripper_pose.shape)
            #print('step_size',step_size.shape)
            # To predict
            optimizer.zero_grad()

            # The network jointly predicts the keypoint xy-depth
            # (batch_size, num_keypoint, 3) plus the gripper rotation and
            # translation deltas and the step size
            xy_depth_pred, delta_rot_pred, delta_xyz_pred, step_size_pred = network(
                image,
                gripper_pose,
                device,
                enableKeypointPos=enableKeypointPos)
            #print((1-criterion_cos(torch.tensor([[0.01,0.01,0.01],[0.01,0.01,0.01]]).to(device), torch.tensor([[0.0,0.0,0.0],[0.0,0.0,0.0]]).to(device))).mean())
            #gripper control network
            #raw_pred_flatten = torch.flatten(raw_pred, start_dim=1)
            #delta_rot_pred, delta_xyz_pred, step_size_pred = control_network(raw_pred_flatten)

            #identity = torch.eye(3).unsqueeze(0).repeat(image.shape[0],1,1).to(device)
            #identity_hat = torch.matmul(torch.transpose(delta_rot_pred, 1, 2), delta_rot)
            #loss_r = criterion_rmse(identity, identity_hat)
            loss_r = criterion_rmse(delta_rot_pred, delta_rot)
            loss_t = (1 - criterion_cos(delta_xyz_pred, delta_xyz)
                      ).mean() + criterion_rmse(delta_xyz_pred, delta_xyz)
            #loss_t = criterion_rmse(delta_xyz_pred, delta_xyz)
            loss_s = criterion_bce(step_size_pred, step_size).mean()
            ''' 
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            heatmap = predict.heatmap_from_predict(prob_pred, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = heatmap.shape
            #print(raw_pred.shape)
            #print(prob_pred.shape)
            #print(depthmap_pred.shape)
            # Compute the coordinate
            if device == 'cpu':
                coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_cpu(heatmap, net_config.num_keypoints)
            else:
                coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)
            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)
            '''
            # Compute loss
            loss_kpt = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth,
                                        keypoint_weight)
            loss = loss_kpt * 15 + loss_r * 10 + loss_t * 10 + loss_s
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 2],
                                 keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            '''
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ', 256 * xy_error / len(xy_depth_pred))
                print('The averaged depth error is (mm): ', 256 * depth_error / len(xy_depth_pred))
                print('The move error is', loss_move.item())
            '''
            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_rot += loss_r.item()
            train_error_xyz += loss_t.item()
            train_error_step += loss_s.item()
            # cleanup
            del loss

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        #print('The training averaged move error is: ', train_error_move / len(dataset_train))
        print('The training averaged rot error is: ',
              train_error_rot / len(dataset_train))
        print('The training averaged xyz error is: ',
              train_error_xyz / len(dataset_train))
        print('The training averaged step error is: ',
              train_error_step / len(dataset_train))
        writer.add_scalar('average pixel error',
                          256 * train_error_xy / len(dataset_train), epoch)
        writer.add_scalar(
            'average depth error', train_config.depth_image_scale *
            train_error_depth / len(dataset_train), epoch)
        # The move loss is disabled above, so skip its (always-zero) scalar
        #writer.add_scalar('average move error',
        #                  train_error_move / len(dataset_train), epoch)
        writer.add_scalar('average rot error',
                          train_error_rot / len(dataset_train), epoch)
        writer.add_scalar('average xyz error',
                          train_error_xyz / len(dataset_train), epoch)
        writer.add_scalar('average step error',
                          train_error_step / len(dataset_train), epoch)
    writer.close()
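# --- Assumed loss helpers -----------------------------------------------------
# The training loops in this file call RMSELoss() and
# weighted_l1_loss(pred, target, weight) without defining them. The sketches
# below are hypothetical reconstructions, consistent with how they are called
# (scalar losses; per-element validity weights; the logging code divides the
# summed error by the batch or dataset size). The repository's actual
# implementations may differ.
import torch


class RMSELoss(torch.nn.Module):
    # Root-mean-square error with an eps inside the sqrt for stable gradients.
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.mse = torch.nn.MSELoss()
        self.eps = eps

    def forward(self, pred, target):
        return torch.sqrt(self.mse(pred, target) + self.eps)


def weighted_l1_loss(pred, target, weight):
    # Validity-masked L1 error, summed over all elements; dividing the result
    # by the batch size (as the logging code does) gives a per-sample sum.
    return (weight * (pred - target).abs()).sum()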
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=4)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()
    #heatmap_criterion = torch.nn.KLDivLoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # The learning rate step
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            #keypoint_xy_depth (batch_size, num_keypoint, xydepth)
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            #if idx  == 0:
            #    np.save('rgbd.npy', image[0])
            #    np.save('target_heatmap.npy', target_heatmap[0])
            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()

            # raw_pred (batch_size, num_keypoint*2, network_out_map_height, network_out_map_width)
            # prob_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            # depthmap_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)
            #heatmap = prob_pred
            _, _, heatmap_height, heatmap_width = heatmap.shape
            #print(raw_pred.shape)
            #print(prob_pred.shape)
            #print(depthmap_pred.shape)
            # Compute the coordinate
            #coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)
            depth_pred = depth_pred[:, :, 0]
            # Concatenate them
            #xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            depth_loss = weighted_l1_loss(depth_pred,
                                          keypoint_xy_depth[:, :, 2],
                                          keypoint_weight[:, :, 2])
            heatmap_loss = heatmap_criterion(heatmap, target_heatmap)
            if idx % 100 == 0:
                # Occasionally dump the heatmaps for offline inspection
                # (saving on every iteration would hammer the disk)
                np.save('pred_heatmap.npy', heatmap.cpu().detach().numpy())
                np.save('target_heatmap.npy',
                        target_heatmap.cpu().detach().numpy())
                print('depth loss:', depth_loss.item())
                print('heatmap loss:', heatmap_loss.item())
            #loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            #loss = depth_loss + 1500*heatmap_loss
            loss = heatmap_loss
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            #xy_error = float(weighted_l1_loss(xy_depth_pred[:, :, 0:2], keypoint_xy_depth[:, :, 0:2], keypoint_weight[:, :, 0:2]).item())
            #depth_error = float(weighted_l1_loss(xy_depth_pred[:, :, 2], keypoint_xy_depth[:, :, 2], keypoint_weight[:, :, 2]).item())
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                heatmap)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(depth_pred, keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        writer.add_scalar('average pixel error',
                          256 * train_error_xy / len(dataset_train), epoch)
        writer.add_scalar(
            'average depth error', train_config.depth_image_scale *
            train_error_depth / len(dataset_train), epoch)
    writer.close()
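# --- Assumed helpers from the predict module ----------------------------------
# The example above turns raw per-keypoint score maps into probability heatmaps
# (predict.heatmap_from_predict) and integrates the predicted depth map under
# them (predict.depth_integration). Minimal sketches of plausible
# implementations, assuming a spatial softmax and a heatmap-weighted
# expectation; the real predict module may differ in details.
import torch


def heatmap_from_predict(prob_pred: torch.Tensor,
                         num_keypoints: int) -> torch.Tensor:
    # prob_pred: (batch, num_keypoints, H, W) raw scores -> per-keypoint
    # probability maps via a softmax over the spatial dimensions.
    assert prob_pred.shape[1] == num_keypoints
    batch, _, height, width = prob_pred.shape
    flat = prob_pred.reshape(batch, num_keypoints, height * width)
    heatmap = torch.nn.functional.softmax(flat, dim=2)
    return heatmap.reshape(batch, num_keypoints, height, width)


def depth_integration(heatmap: torch.Tensor,
                      depthmap_pred: torch.Tensor) -> torch.Tensor:
    # Expected depth under each keypoint's probability map. Both inputs are
    # (batch, num_keypoints, H, W); returns (batch, num_keypoints, 1) so it
    # can be concatenated with the x/y coordinates along dim=2.
    batch, num_keypoints, _, _ = heatmap.shape
    weighted = (heatmap * depthmap_pred).reshape(batch, num_keypoints, -1)
    return weighted.sum(dim=2, keepdim=True)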
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=16,
                              shuffle=True,
                              num_workers=4)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=16,
                            shuffle=False,
                            num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 100 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0

        # The learning rate step
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            _, _, heatmap_height, heatmap_width = prob_pred.shape

            # Compute loss
            loss = heatmap_criterion(prob_pred, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Do some pred and log
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                prob_pred)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))

        # Prepare info for validation
        network.eval()
        val_error_xy = 0

        # The validation iteration of the data
        for idx, data in enumerate(loader_val):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]

            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict (no gradient tracking needed for validation)
            with torch.no_grad():
                pred = network(image)
            prob_pred = pred[:, 0:net_config.num_keypoints, :, :]
            _, _, heatmap_height, heatmap_width = prob_pred.shape

            # Compute the coordinate
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                prob_pred)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())

            # Update info
            val_error_xy += float(xy_error)

        # The info at epoch level
        print(
            'The validation averaged pixel error is (pixel in 256x256 image): ',
            256 * val_error_xy / len(dataset_val))
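# --- Assumed helper: predict.heatmap2d_to_normalized_imgcoord_argmax ----------
# The logging above reads keypoint coordinates off the heatmap argmax. A
# hypothetical sketch, assuming normalized coordinates in [-0.5, 0.5]; the
# repository's actual convention may differ.
import torch


def heatmap2d_to_normalized_imgcoord_argmax(heatmap: torch.Tensor):
    # heatmap: (batch, num_keypoints, H, W)
    batch, num_keypoints, height, width = heatmap.shape
    flat_idx = heatmap.reshape(batch, num_keypoints, -1).argmax(dim=2)
    row = torch.div(flat_idx, width, rounding_mode='floor').float()
    col = (flat_idx % width).float()
    # (batch, num_keypoints, 2) normalized xy, plus the raw argmax indices
    coord = torch.stack([col / width - 0.5, row / height - 0.5], dim=2)
    return coord, flat_idx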
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=8,
                              shuffle=True,
                              num_workers=4)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()

    # To cuda
    network = torch.nn.DataParallel(network).cuda()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))

    # The checkpoint
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 2 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_internal = 0

        # The learning rate step
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]

            # Move to gpu
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)

            # The last stage
            raw_pred_last = raw_pred[-1]
            prob_pred_last = raw_pred_last[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred_last = raw_pred_last[:,
                                               net_config.num_keypoints:, :, :]
            heatmap_last = predict.heatmap_from_predict(
                prob_pred_last, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = heatmap_last.shape

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                heatmap_last, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap_last,
                                                   depthmap_pred_last)

            # Concatenate them
            xy_depth_pred = torch.cat([coord_x, coord_y, depth_pred], dim=2)

            # Compute loss
            loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth,
                                    keypoint_weight)

            # For all other layers
            for stage_i in range(len(raw_pred) - 1):
                # Only the 2d prediction is supervised on earlier stages
                prob_pred_i = raw_pred[stage_i]
                assert prob_pred_i.shape == prob_pred_last.shape
                heatmap_i = predict.heatmap_from_predict(
                    prob_pred_i, net_config.num_keypoints)
                coord_x_i, coord_y_i = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap_i, net_config.num_keypoints)
                xy_pred_i = torch.cat([coord_x_i, coord_y_i], dim=2)
                loss = loss + weighted_l1_loss(xy_pred_i,
                                               keypoint_xy_depth[:, :, 0:2],
                                               keypoint_weight[:, :, 0:2])

            # The SGD step
            loss.backward()
            optimizer.step()
            del loss

            # Log info
            xy_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 2],
                                 keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            # The error of internal stage
            keypoint_xy_pred_internal, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                raw_pred[0])
            xy_error_internal = float(
                weighted_l1_loss(keypoint_xy_pred_internal[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))
                print(
                    'The averaged internal pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_internal / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_internal += float(xy_error_internal)

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        print(
            'The training internal averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_internal / len(dataset_train))
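# --- Assumed helper: predict.heatmap2d_to_normalized_imgcoord_gpu -------------
# The integral ("soft-argmax") coordinates used above: the expected x and y
# positions under each keypoint's probability heatmap, in the same normalized
# convention as the argmax variant. A hypothetical sketch:
import torch


def heatmap2d_to_normalized_imgcoord_gpu(heatmap: torch.Tensor,
                                         num_keypoints: int):
    # heatmap: (batch, num_keypoints, H, W), each map summing to 1.
    assert heatmap.shape[1] == num_keypoints
    _, _, height, width = heatmap.shape
    xs = torch.linspace(0.0, 1.0, width, device=heatmap.device)
    ys = torch.linspace(0.0, 1.0, height, device=heatmap.device)
    # Marginalize over rows/columns, then take the expectation.
    px = heatmap.sum(dim=2)  # (batch, K, W): distribution over x
    py = heatmap.sum(dim=3)  # (batch, K, H): distribution over y
    coord_x = (px * xs).sum(dim=2, keepdim=True) - 0.5  # (batch, K, 1)
    coord_y = (py * ys).sum(dim=2, keepdim=True) - 0.5
    return coord_x, coord_y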
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 4 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_heatmap = 0

        # The learning rate step
        scheduler.step()
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to cuda
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            num_kpt = net_config.num_keypoints
            prob_pred = raw_pred[:, 0:num_kpt, :, :]
            depthmap_pred = raw_pred[:, num_kpt:2 * num_kpt, :, :]
            regress_heatmap = raw_pred[:, 2 * num_kpt:, :, :]
            integral_heatmap = predict.heatmap_from_predict(
                prob_pred, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = integral_heatmap.shape

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                integral_heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(integral_heatmap,
                                                   depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            loss = weighted_mse_loss(xy_depth_pred, keypoint_xy_depth,
                                     keypoint_weight)
            loss = loss + heatmap_loss_weight * heatmap_criterion(
                regress_heatmap, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 2],
                                 keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            keypoint_xy_pred_heatmap, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                regress_heatmap)
            xy_error_heatmap = float(
                weighted_l1_loss(keypoint_xy_pred_heatmap[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))
                print(
                    'The averaged heatmap argmax pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_heatmap / len(xy_depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_heatmap += float(xy_error_heatmap)

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        print(
            'The training averaged heatmap pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_heatmap / len(dataset_train))
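# --- Assumed helper: weighted_mse_loss -----------------------------------------
# The last example swaps the keypoint term from weighted L1 to a validity-
# weighted squared error. A hypothetical sketch, mirroring the assumed
# weighted_l1_loss above; the repository's actual implementation may differ.
import torch


def weighted_mse_loss(pred: torch.Tensor,
                      target: torch.Tensor,
                      weight: torch.Tensor) -> torch.Tensor:
    # Validity-masked squared error, summed over all elements.
    return (weight * (pred - target) ** 2).sum()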