Example #1
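All four variants below assume roughly the following imports. The project-specific helpers (`parameter`, `predict`, `construct_dataset`, `construct_network`, `init_from_modelzoo`, `weighted_l1_loss`, `weighted_mse_loss`) and the module-level globals (`learning_rate`, `n_epoch`, `heatmap_loss_weight`, `writer`) are assumed to come from the surrounding keypoint-detection package.

import os
import numpy as np
import torch
from torch.utils.data import DataLoader
# The TensorBoard writer used by Example #1 is assumed to be created elsewhere, e.g.:
# from torch.utils.tensorboard import SummaryWriter
# writer = SummaryWriter()
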
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=64,
                              shuffle=True,
                              num_workers=4)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()
    #heatmap_criterion = torch.nn.KLDivLoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 20 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0

        # Print the current learning rate
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            # keypoint_xy_depth: (batch_size, num_keypoint, xy + depth)
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()

            # raw_pred (batch_size, num_keypoint*2, network_out_map_height, network_out_map_width)
            # prob_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            # depthmap_pred (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:, :, :]
            # heatmap (batch_size, num_keypoint, network_out_map_height, network_out_map_width)
            heatmap = predict.heatmap_from_predict(prob_pred,
                                                   net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = heatmap.shape

            # Compute the coordinate (the xy branch is disabled in this variant)
            # coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap, depthmap_pred)
            depth_pred = depth_pred[:, :, 0]
            # Concatenate them (for the disabled xy+depth loss below)
            # xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            depth_loss = weighted_l1_loss(depth_pred,
                                          keypoint_xy_depth[:, :, 2],
                                          keypoint_weight[:, :, 2])
            heatmap_loss = heatmap_criterion(heatmap, target_heatmap)
            if idx % 100 == 0:
                # Periodically dump heatmaps for offline inspection
                np.save('pred_heatmap.npy', heatmap.detach().cpu().numpy())
                np.save('target_heatmap.npy',
                        target_heatmap.detach().cpu().numpy())
                print('depth loss:', depth_loss)
                print('heatmap loss:', heatmap_loss)
            # Alternative losses that were experimented with:
            # loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth, keypoint_weight)
            # loss = depth_loss + 1500 * heatmap_loss
            loss = heatmap_loss
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Log info
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                heatmap)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(depth_pred, keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)

        # Step the LR scheduler once per epoch, after the optimizer updates
        # (the ordering required since PyTorch 1.1)
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        writer.add_scalar('average pixel error',
                          256 * train_error_xy / len(dataset_train), epoch)
        writer.add_scalar(
            'average depth error', train_config.depth_image_scale *
            train_error_depth / len(dataset_train), epoch)
    writer.close()
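`weighted_l1_loss` appears in every variant but is not defined above. A minimal sketch consistent with how it is called (per-element L1 masked by the keypoint-validity weights and summed, so callers can average by batch size):

def weighted_l1_loss(pred: torch.Tensor,
                     target: torch.Tensor,
                     weight: torch.Tensor) -> torch.Tensor:
    # Mask out invalid keypoints, then sum over batch/keypoints/channels
    return (weight * torch.abs(pred - target)).sum()
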
Example #2
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=16,
                              shuffle=True,
                              num_workers=4)
    loader_val = DataLoader(dataset=dataset_val,
                            batch_size=16,
                            shuffle=False,
                            num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 100 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0

        # Print the current learning rate
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to GPU
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            _, _, heatmap_height, heatmap_width = prob_pred.shape

            # Compute loss
            loss = heatmap_criterion(prob_pred, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # cleanup
            del loss

            # Do some pred and log
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                prob_pred)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)

        # Step the LR scheduler once per epoch, after the optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))

        # Prepare info for validation
        network.eval()
        val_error_xy = 0

        # The validation iteration of the data
        for idx, data in enumerate(loader_val):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]

            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict (no gradients needed for validation)
            with torch.no_grad():
                pred = network(image)
            prob_pred = pred[:, 0:net_config.num_keypoints, :, :]
            _, _, heatmap_height, heatmap_width = prob_pred.shape

            # Compute the coordinate
            keypoint_xy_pred, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                prob_pred)
            xy_error = float(
                weighted_l1_loss(keypoint_xy_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())

            # Update info
            val_error_xy += float(xy_error)

        # The info at epoch level
        print(
            'The validation averaged pixel error is (pixel in 256x256 image): ',
            256 * val_error_xy / len(dataset_val))
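Examples #1, #3, and #4 convert the raw per-keypoint score maps into normalized heatmaps with `predict.heatmap_from_predict`. A plausible sketch, assuming it is a softmax over the flattened spatial dimensions so that each keypoint map sums to one:

def heatmap_from_predict(prob_pred: torch.Tensor,
                         num_keypoints: int) -> torch.Tensor:
    # prob_pred: (batch, num_keypoints, height, width) raw scores
    assert prob_pred.shape[1] == num_keypoints
    batch, _, height, width = prob_pred.shape
    flat = prob_pred.reshape(batch, num_keypoints, height * width)
    heatmap = torch.nn.functional.softmax(flat, dim=2)
    return heatmap.reshape(batch, num_keypoints, height, width)
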
Example #3
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)
    # dataset_val, val_config = construct_dataset(is_train=False)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=8,
                              shuffle=True,
                              num_workers=4)
    # loader_val = DataLoader(dataset=dataset_val, batch_size=16, shuffle=False, num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()

    # To cuda
    network = torch.nn.DataParallel(network).cuda()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))

    # The checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80],
                                                     gamma=0.1)

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 2 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_internal = 0

        # Print the current learning rate
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]

            # Move to gpu
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)

            # The last stage
            raw_pred_last = raw_pred[-1]
            prob_pred_last = raw_pred_last[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred_last = raw_pred_last[:, net_config.num_keypoints:, :, :]
            heatmap_last = predict.heatmap_from_predict(
                prob_pred_last, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = heatmap_last.shape

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                heatmap_last, net_config.num_keypoints)
            depth_pred = predict.depth_integration(heatmap_last,
                                                   depthmap_pred_last)

            # Concatenate them
            xy_depth_pred = torch.cat([coord_x, coord_y, depth_pred], dim=2)

            # Compute loss
            loss = weighted_l1_loss(xy_depth_pred, keypoint_xy_depth,
                                    keypoint_weight)

            # For all other layers
            for stage_i in range(len(raw_pred) - 1):
                # Only the 2D prediction is supervised on the intermediate stages
                prob_pred_i = raw_pred[stage_i]
                assert prob_pred_i.shape == prob_pred_last.shape
                heatmap_i = predict.heatmap_from_predict(
                    prob_pred_i, net_config.num_keypoints)
                coord_x_i, coord_y_i = predict.heatmap2d_to_normalized_imgcoord_gpu(
                    heatmap_i, net_config.num_keypoints)
                xy_pred_i = torch.cat([coord_x_i, coord_y_i], dim=2)
                loss = loss + weighted_l1_loss(xy_pred_i,
                                               keypoint_xy_depth[:, :, 0:2],
                                               keypoint_weight[:, :, 0:2])

            # The SGD step
            loss.backward()
            optimizer.step()
            del loss

            # Log info
            xy_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 2],
                                 keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            # The error of internal stage
            keypoint_xy_pred_internal, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                raw_pred[0])
            xy_error_internal = float(
                weighted_l1_loss(keypoint_xy_pred_internal[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print('The averaged depth error is (mm): ',
                      train_config.depth_image_scale * depth_error /
                      len(xy_depth_pred))
                print(
                    'The averaged internal pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_internal / image.shape[0])

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_internal += float(xy_error_internal)

        # Step the LR scheduler once per epoch, after the optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        print(
            'The training internal averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_internal / len(dataset_train))
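`predict.heatmap2d_to_normalized_imgcoord_gpu` and `predict.depth_integration` implement the integral (soft-argmax) regression used above. A sketch under the assumptions that coordinates are normalized image coordinates in [-0.5, 0.5] and that both return (batch, num_keypoints, 1) tensors, which matches the `torch.cat(..., dim=2)` calls:

def heatmap2d_to_normalized_imgcoord_gpu(heatmap, num_keypoints):
    # heatmap: (batch, num_keypoints, height, width); each map sums to one
    batch, _, height, width = heatmap.shape
    dtype, device = heatmap.dtype, heatmap.device
    # Pixel-center coordinates normalized to [-0.5, 0.5] (an assumption)
    xs = (torch.arange(width, dtype=dtype, device=device) + 0.5) / width - 0.5
    ys = (torch.arange(height, dtype=dtype, device=device) + 0.5) / height - 0.5
    # Expectations of x and y under each heatmap distribution
    coord_x = (heatmap.sum(dim=2) * xs).sum(dim=2, keepdim=True)
    coord_y = (heatmap.sum(dim=3) * ys).sum(dim=2, keepdim=True)
    return coord_x, coord_y


def depth_integration(heatmap, depthmap_pred):
    # Expected depth under each heatmap distribution: (batch, num_keypoints, 1)
    return (heatmap * depthmap_pred).sum(dim=(2, 3)).unsqueeze(2)
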
Example #4
def train(checkpoint_dir: str,
          start_from_ckpnt: str = '',
          save_epoch_offset: int = 0):
    # Construct the dataset
    dataset_train, train_config = construct_dataset(is_train=True)

    # And the dataloader
    loader_train = DataLoader(dataset=dataset_train,
                              batch_size=32,
                              shuffle=True,
                              num_workers=4)

    # Construct the regressor
    network, net_config = construct_network()
    if len(start_from_ckpnt) > 0:
        network.load_state_dict(torch.load(start_from_ckpnt))
    else:
        init_from_modelzoo(network, net_config)
    network.cuda()

    # The checkpoint directory
    os.makedirs(checkpoint_dir, exist_ok=True)

    # The optimizer and scheduler
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 90],
                                                     gamma=0.1)

    # The loss for heatmap
    heatmap_criterion = torch.nn.MSELoss().cuda()

    # The training loop
    for epoch in range(n_epoch):
        # Save the network
        if epoch % 4 == 0 and epoch > 0:
            file_name = 'checkpoint-%d.pth' % (epoch + save_epoch_offset)
            checkpoint_path = os.path.join(checkpoint_dir, file_name)
            print('Save the network at %s' % checkpoint_path)
            torch.save(network.state_dict(), checkpoint_path)

        # Prepare info for training
        network.train()
        train_error_xy = 0
        train_error_depth = 0
        train_error_xy_heatmap = 0

        # Print the current learning rate
        for param_group in optimizer.param_groups:
            print('The learning rate is ', param_group['lr'])

        # The training iteration over the dataset
        for idx, data in enumerate(loader_train):
            # Get the data
            image = data[parameter.rgbd_image_key]
            keypoint_xy_depth = data[parameter.keypoint_xyd_key]
            keypoint_weight = data[parameter.keypoint_validity_key]
            target_heatmap = data[parameter.target_heatmap_key]

            # Upload to cuda
            image = image.cuda()
            keypoint_xy_depth = keypoint_xy_depth.cuda()
            keypoint_weight = keypoint_weight.cuda()
            target_heatmap = target_heatmap.cuda()

            # To predict
            optimizer.zero_grad()
            raw_pred = network(image)
            prob_pred = raw_pred[:, 0:net_config.num_keypoints, :, :]
            depthmap_pred = raw_pred[:, net_config.num_keypoints:2 * net_config.num_keypoints, :, :]
            regress_heatmap = raw_pred[:, 2 * net_config.num_keypoints:, :, :]
            integral_heatmap = predict.heatmap_from_predict(
                prob_pred, net_config.num_keypoints)
            _, _, heatmap_height, heatmap_width = integral_heatmap.shape

            # Compute the coordinate
            coord_x, coord_y = predict.heatmap2d_to_normalized_imgcoord_gpu(
                integral_heatmap, net_config.num_keypoints)
            depth_pred = predict.depth_integration(integral_heatmap,
                                                   depthmap_pred)

            # Concatenate them
            xy_depth_pred = torch.cat((coord_x, coord_y, depth_pred), dim=2)

            # Compute loss
            loss = weighted_mse_loss(xy_depth_pred, keypoint_xy_depth,
                                     keypoint_weight)
            loss = loss + heatmap_loss_weight * heatmap_criterion(
                regress_heatmap, target_heatmap)

            # Do update
            loss.backward()
            optimizer.step()

            # Log info
            xy_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            depth_error = float(
                weighted_l1_loss(xy_depth_pred[:, :, 2],
                                 keypoint_xy_depth[:, :, 2],
                                 keypoint_weight[:, :, 2]).item())
            keypoint_xy_pred_heatmap, _ = predict.heatmap2d_to_normalized_imgcoord_argmax(
                regress_heatmap)
            xy_error_heatmap = float(
                weighted_l1_loss(keypoint_xy_pred_heatmap[:, :, 0:2],
                                 keypoint_xy_depth[:, :, 0:2],
                                 keypoint_weight[:, :, 0:2]).item())
            if idx % 100 == 0:
                print('Iteration %d in epoch %d' % (idx, epoch))
                print('The averaged pixel error is (pixel in 256x256 image): ',
                      256 * xy_error / len(xy_depth_pred))
                print(
                    'The averaged depth error is (mm): ',
                    train_config.depth_image_scale * depth_error /
                    len(xy_depth_pred))
                print(
                    'The averaged heatmap argmax pixel error is (pixel in 256x256 image): ',
                    256 * xy_error_heatmap / len(xy_depth_pred))

            # Update info
            train_error_xy += float(xy_error)
            train_error_depth += float(depth_error)
            train_error_xy_heatmap += float(xy_error_heatmap)

        # Step the LR scheduler once per epoch, after the optimizer updates
        scheduler.step()

        # The info at epoch level
        print('Epoch %d' % epoch)
        print(
            'The training averaged pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy / len(dataset_train))
        print(
            'The training averaged depth error is (mm): ',
            train_config.depth_image_scale * train_error_depth /
            len(dataset_train))
        print(
            'The training averaged heatmap pixel error is (pixel in 256x256 image): ',
            256 * train_error_xy_heatmap / len(dataset_train))
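Example #4 also relies on `weighted_mse_loss`, the squared-error analogue of the L1 loss above. A sketch under the same masking convention, followed by a hypothetical invocation of the entry point:

def weighted_mse_loss(pred: torch.Tensor,
                      target: torch.Tensor,
                      weight: torch.Tensor) -> torch.Tensor:
    # Squared error masked by keypoint validity, summed
    return (weight * (pred - target) ** 2).sum()


if __name__ == '__main__':
    # Hypothetical usage: train from scratch, saving into ./checkpoints
    train(checkpoint_dir='checkpoints',
          start_from_ckpnt='',  # or a path to a saved .pth state dict
          save_epoch_offset=0)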