Exemplo n.º 1
0
def construct_dataset(is_train: bool, db_config_file_path, time_sequence_length=None,
                      pdc_data_root: str = "/home/monti/Desktop/pdc",
                      keypoint_yaml_name: str = 'stick_2_keypoint_image.yaml'):
    """Build a supervised keypoint dataset from a Spartan keypoint database.

    -> (torch.utils.data.Dataset, SupervisedKeypointDatasetConfig)

    Args:
        is_train: Copied into ``config.is_train`` of the dataset config.
        db_config_file_path: Log-list file handed to the database config as
            ``config_file_path``.
        time_sequence_length: If truthy, the dataset is wrapped in
            ``TimeseriesWrapper(dataset, time_sequence_length)``; ``None`` (or
            ``0``) returns the plain ``SupervisedKeypointDataset``.
        pdc_data_root: Root directory of the pdc data. The default is the
            original machine-specific location; pass another root to run
            elsewhere.
        keypoint_yaml_name: Keypoint yaml name given to the database config.

    Returns:
        ``(dataset, config)``: the (optionally time-series wrapped) dataset and
        the ``SupervisedKeypointDatasetConfig`` used to build it.
    """
    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = keypoint_yaml_name
    db_config.pdc_data_root = pdc_data_root
    db_config.config_file_path = db_config_file_path
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset: square input patch, 64x64 output map.
    # NOTE(review): ``segmented_img_size`` is not defined in this function;
    # presumably a module-level constant -- confirm.
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = segmented_img_size
    config.network_in_patch_height = segmented_img_size
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)

    # Truthiness test on purpose: None and 0 both mean "no time-series wrapping".
    if time_sequence_length:
        dataset = TimeseriesWrapper(dataset, time_sequence_length)

    return dataset, config
def construct_dataset(
    is_train: bool
) -> (SupervisedKeypointDataset, SupervisedKeypointDatasetConfig):
    """Create the keypoint dataset for the ``mugs_up_with_flat`` log lists.

    ``is_train`` selects the log list (``mugs_up_with_flat_logs.txt`` when
    truthy, ``mugs_up_with_flat_test_logs.txt`` otherwise) and is also stored
    in ``config.is_train``. The database uses ``camera_config.yaml`` as its
    keypoint yaml; the network input patch is 256x256 and the output map is
    64x64.

    Returns:
        ``(dataset, config)``: the ``SupervisedKeypointDataset`` and the
        ``SupervisedKeypointDatasetConfig`` it was built from.
    """
    log_list = (
        '/home/wei/Coding/mankey/config/mugs_up_with_flat_logs.txt'
        if is_train
        else '/home/wei/Coding/mankey/config/mugs_up_with_flat_test_logs.txt'
    )

    # Describe and open the Spartan keypoint database.
    spartan_config = SpartanSupvervisedKeypointDBConfig()
    spartan_config.keypoint_yaml_name = 'camera_config.yaml'
    spartan_config.pdc_data_root = '/home/wei/data/pdc'
    spartan_config.config_file_path = log_list
    keypoint_db = SpartanSupervisedKeypointDatabase(spartan_config)

    # Wrap it into a torch dataset with the network geometry.
    dataset_config = SupervisedKeypointDatasetConfig()
    dataset_config.network_in_patch_width = 256
    dataset_config.network_in_patch_height = 256
    dataset_config.network_out_map_width = 64
    dataset_config.network_out_map_height = 64
    dataset_config.image_database_list.append(keypoint_db)
    dataset_config.is_train = is_train
    return SupervisedKeypointDataset(dataset_config), dataset_config
def construct_dataset(
    is_train: bool,
    config_file_path=None,
) -> (SupervisedKeypointDataset, SupervisedKeypointDatasetConfig):
    """Build the keypoint dataset for the ``mugs_20201210`` log list.

    Args:
        is_train: Stored in ``config.is_train``. It does NOT select a different
            log list: train and test both read ``mugs_20201210.txt`` unless
            ``config_file_path`` is given.
        config_file_path: Optional override for the log-list file; ``None``
            keeps the default ``mugs_20201210.txt`` path.
            NOTE(review): because both modes share the default list, evaluating
            with ``is_train=False`` reuses the training logs -- pass a held-out
            list here if a real test split exists.

    Returns:
        ``(dataset, config)``: the ``SupervisedKeypointDataset`` and its
        ``SupervisedKeypointDatasetConfig``.
    """
    if config_file_path is None:
        config_file_path = '/home/luben/robotic-arm-task-oriented-manipulation/mankey/config/mugs_20201210.txt'

    # Construct the db info
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'mug_3_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/luben/data/pdc'
    db_config.config_file_path = config_file_path

    # Construct the database
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset: 256x256 input patch, 64x64 output map.
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    return dataset, config
def construct_dataset(
    is_train: bool
) -> (torch.utils.data.Dataset, SupervisedKeypointDatasetConfig):
    """Create the keypoint dataset for the ``boot_logs.txt`` log list.

    The database reads ``shoe_6_keypoint_image.yaml`` under
    ``/home/wei/data/pdc``; the network input patch is 256x256 and the output
    map is 64x64. ``is_train`` is only forwarded to ``config.is_train``.

    Returns:
        ``(dataset, config)``: the ``SupervisedKeypointDataset`` and the
        ``SupervisedKeypointDatasetConfig`` used to build it.
    """
    # Describe and open the Spartan keypoint database.
    spartan_config = SpartanSupvervisedKeypointDBConfig()
    spartan_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    spartan_config.pdc_data_root = '/home/wei/data/pdc'
    spartan_config.config_file_path = '/home/wei/Coding/mankey/config/boot_logs.txt'
    keypoint_db = SpartanSupervisedKeypointDatabase(spartan_config)

    # Network geometry, then the backing database and mode flag.
    dataset_config = SupervisedKeypointDatasetConfig()
    dataset_config.network_in_patch_width = dataset_config.network_in_patch_height = 256
    dataset_config.network_out_map_width = dataset_config.network_out_map_height = 64
    dataset_config.image_database_list.append(keypoint_db)
    dataset_config.is_train = is_train

    return SupervisedKeypointDataset(dataset_config), dataset_config
Exemplo n.º 5
0
def main():
    """Print the min, max and mean keypoint depth over a keypoint dataset.

    Builds a ``SupervisedKeypointDataset`` with ``is_train = False`` from the
    ``shoe_logs.txt`` log list, runs ``get_processed_entry`` on every entry and
    aggregates row 2 of ``keypoint_xy_depth`` (presumably depth, given the
    field name) over every keypoint column of every entry. If no keypoint is
    visited, a message is printed instead of the statistics.
    """
    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    db_config.config_file_path = '/home/wei/Coding/mankey/config/shoe_logs.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_height = 256
    config.network_in_patch_width = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = False
    dataset = SupervisedKeypointDataset(config)

    # Running statistics. Infinite sentinels keep min/max correct whatever
    # range the depth values fall in.
    min_keypoint_depth = float('inf')
    max_keypoint_depth = float('-inf')
    depth_sum = 0.0
    counter = 0

    # Iterate
    for i in range(len(dataset)):
        processed_entry = dataset.get_processed_entry(dataset.entry_list[i])
        keypoint_xy_depth = processed_entry.keypoint_xy_depth
        for j in range(keypoint_xy_depth.shape[1]):
            depth = keypoint_xy_depth[2, j]
            min_keypoint_depth = min(min_keypoint_depth, depth)
            max_keypoint_depth = max(max_keypoint_depth, depth)
            depth_sum += depth
            counter += 1

    # Nothing visited (empty dataset or no keypoints): avoid dividing by zero.
    if counter == 0:
        print('No keypoints found; nothing to report')
        return

    # Some output
    print('The min value is ', min_keypoint_depth)
    print('The max value is ', max_keypoint_depth)
    print('The average value is ', depth_sum / float(counter))
Exemplo n.º 6
0
def save_loaded_img(max_images: int = 1000):
    """Write randomly sampled processed entries of a keypoint dataset to PNGs.

    Builds a ``SupervisedKeypointDataset`` (``is_train = False``) from
    ``boot_logs.txt`` and prints its length. Then, for
    ``min(max_images, len(dataset))`` draws, it picks a random entry index
    (with replacement, so an index may repeat and overwrite its files) and
    writes into ``<directory of this file>/tmp``:

    * ``image_<idx>_rgb.png``: ``draw_image_keypoint`` on the cropped RGB;
    * ``mask_image_<idx>_rgb.png``: ``get_visible_mask`` of the cropped mask.

    Args:
        max_images: Upper bound on the number of draws (default 1000).
    """
    import os
    from mankey.dataproc.spartan_supervised_db import SpartanSupvervisedKeypointDBConfig, SpartanSupervisedKeypointDatabase
    from mankey.utils.imgproc import draw_image_keypoint, get_visible_mask

    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    db_config.config_file_path = '/home/wei/Coding/mankey/config/boot_logs.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = False
    dataset = SupervisedKeypointDataset(config)

    # Simple check
    print(len(dataset))
    tmp_dir = os.path.join(os.path.dirname(__file__), 'tmp')
    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs(tmp_dir, exist_ok=True)

    # Save the warped images.
    # NOTE(review): ``random`` and ``cv2`` are not imported here; presumably
    # module-level imports -- confirm.
    for _ in range(min(max_images, len(dataset))):
        idx = random.randint(0, len(dataset) - 1)
        processed_entry = dataset.get_processed_entry(dataset.entry_list[idx])
        rgb_keypoint = draw_image_keypoint(processed_entry.cropped_rgb,
                                           processed_entry.keypoint_xy_depth,
                                           processed_entry.keypoint_validity)
        rgb_path = os.path.join(tmp_dir, 'image_%d_rgb.png' % idx)
        mask_path = os.path.join(tmp_dir, 'mask_image_%d_rgb.png' % idx)
        # cv2.imwrite reports failure through its boolean return value rather
        # than raising, so check it instead of dropping the result.
        if not cv2.imwrite(rgb_path, rgb_keypoint):
            print('Failed to write', rgb_path)
        if not cv2.imwrite(mask_path,
                           get_visible_mask(processed_entry.cropped_binary_mask)):
            print('Failed to write', mask_path)
Exemplo n.º 7
0
def construct_dataset(
    is_train: bool,
    config_file_path=None,
) -> (torch.utils.data.Dataset, SupervisedKeypointDatasetConfig):
    """Build the keypoint dataset for the ``peg_in_hole.yaml`` annotations.

    Args:
        is_train: Stored in ``config.is_train``. It does NOT select a different
            log list: train and test both read
            ``box_insertion_fix_20210529.txt`` unless ``config_file_path`` is
            given.
        config_file_path: Optional override for the log-list file; ``None``
            keeps the default ``box_insertion_fix_20210529.txt`` path.
            NOTE(review): because both modes share the default list, evaluating
            with ``is_train=False`` reuses the training logs -- pass a held-out
            list here if a real test split exists.

    Returns:
        ``(dataset, config)``: the ``SupervisedKeypointDataset`` and its
        ``SupervisedKeypointDatasetConfig``.
    """
    if config_file_path is None:
        config_file_path = '/tmp2/r09944001/robot-peg-in-hole-task/mankey/config/box_insertion_fix_20210529.txt'

    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'peg_in_hole.yaml'
    db_config.pdc_data_root = '/tmp2/r09944001/data/pdc'
    db_config.config_file_path = config_file_path
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset: 256x256 input patch, 64x64 output map.
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    return dataset, config