def construct_dataset(is_train: bool, db_config_file_path, time_sequence_length=None):
    """ -> (torch.utils.data.Dataset, SupervisedKeypointDatasetConfig) """
    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'stick_2_keypoint_image.yaml'
    db_config.pdc_data_root = "/home/monti/Desktop/pdc"
    db_config.config_file_path = db_config_file_path
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = segmented_img_size
    config.network_in_patch_height = segmented_img_size
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    if time_sequence_length:
        # Optionally wrap the dataset to query time-series data
        dataset = TimeseriesWrapper(dataset, time_sequence_length)
    return dataset, config

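# Minimal usage sketch (an assumption, not part of the original module): wire the dataset
# returned by construct_dataset above into a standard torch DataLoader. The config path,
# batch size and worker count below are illustrative placeholders only.
def _example_dataloader_usage():
    from torch.utils.data import DataLoader
    dataset, dataset_config = construct_dataset(
        is_train=True,
        db_config_file_path='/path/to/pdc_config_logs.txt',  # hypothetical path
        time_sequence_length=None)
    # Batch the keypoint samples for training
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
    return loader
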
def construct_dataset(is_train: bool) -> (SupervisedKeypointDataset, SupervisedKeypointDatasetConfig):
    # Construct the db info
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'camera_config.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    if is_train:
        db_config.config_file_path = '/home/wei/Coding/mankey/config/mugs_up_with_flat_logs.txt'
    else:
        db_config.config_file_path = '/home/wei/Coding/mankey/config/mugs_up_with_flat_test_logs.txt'

    # Construct the database
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    return dataset, config

def construct_dataset(is_train: bool) -> (SupervisedKeypointDataset, SupervisedKeypointDatasetConfig):
    # Construct the db info
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'mug_3_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/luben/data/pdc'
    if is_train:
        db_config.config_file_path = '/home/luben/robotic-arm-task-oriented-manipulation/mankey/config/mugs_20201210.txt'
    else:
        db_config.config_file_path = '/home/luben/robotic-arm-task-oriented-manipulation/mankey/config/mugs_20201210.txt'

    # Construct the database
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    return dataset, config

def construct_dataset(is_train: bool) -> (torch.utils.data.Dataset, SupervisedKeypointDatasetConfig):
    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    db_config.config_file_path = '/home/wei/Coding/mankey/config/boot_logs.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    return dataset, config

def main():
    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    db_config.config_file_path = '/home/wei/Coding/mankey/config/shoe_logs.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_height = 256
    config.network_in_patch_width = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = False
    dataset = SupervisedKeypointDataset(config)

    # The counters
    min_keypoint_depth: float = 10000
    max_keypoint_depth: float = 0
    avg_keypoint_depth: float = 0
    counter = 0

    # Iterate
    for i in range(len(dataset)):
        processed_entry = dataset.get_processed_entry(dataset.entry_list[i])
        n_keypoint = processed_entry.keypoint_xy_depth.shape[1]
        for j in range(n_keypoint):
            min_keypoint_depth = min(min_keypoint_depth, processed_entry.keypoint_xy_depth[2, j])
            max_keypoint_depth = max(max_keypoint_depth, processed_entry.keypoint_xy_depth[2, j])
            avg_keypoint_depth += processed_entry.keypoint_xy_depth[2, j]
            counter += 1

    # Some output
    print('The min value is ', min_keypoint_depth)
    print('The max value is ', max_keypoint_depth)
    print('The average value is ', avg_keypoint_depth / float(counter))

def save_loaded_img():
    from mankey.dataproc.spartan_supervised_db import SpartanSupvervisedKeypointDBConfig, SpartanSupervisedKeypointDatabase

    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    db_config.keypoint_yaml_name = 'shoe_6_keypoint_image.yaml'
    db_config.pdc_data_root = '/home/wei/data/pdc'
    db_config.config_file_path = '/home/wei/Coding/mankey/config/boot_logs.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = False
    dataset = SupervisedKeypointDataset(config)

    # Simple check
    import os
    import random
    import cv2
    print(len(dataset))
    tmp_dir = os.path.join(os.path.dirname(__file__), 'tmp')
    if not os.path.exists(tmp_dir):
        os.mkdir(tmp_dir)

    # Save all the warped images
    from mankey.utils.imgproc import draw_image_keypoint, draw_visible_heatmap, get_visible_mask
    for i in range(min(1000, len(dataset))):
        idx = random.randint(0, len(dataset) - 1)
        processed_entry = dataset.get_processed_entry(dataset.entry_list[idx])
        rgb_keypoint = draw_image_keypoint(
            processed_entry.cropped_rgb,
            processed_entry.keypoint_xy_depth,
            processed_entry.keypoint_validity)
        cv2.imwrite(os.path.join(tmp_dir, 'image_%d_rgb.png' % idx), rgb_keypoint)
        cv2.imwrite(os.path.join(tmp_dir, 'mask_image_%d_rgb.png' % idx), get_visible_mask(processed_entry.cropped_binary_mask))

def construct_dataset(is_train: bool) -> (torch.utils.data.Dataset, SupervisedKeypointDatasetConfig):
    # Construct the db
    db_config = SpartanSupvervisedKeypointDBConfig()
    # db_config.keypoint_yaml_name = 'mug_3_keypoint_image.yaml'
    db_config.keypoint_yaml_name = 'peg_in_hole.yaml'
    db_config.pdc_data_root = '/tmp2/r09944001/data/pdc'
    if is_train:
        db_config.config_file_path = '/tmp2/r09944001/robot-peg-in-hole-task/mankey/config/box_insertion_fix_20210529.txt'
    else:
        db_config.config_file_path = '/tmp2/r09944001/robot-peg-in-hole-task/mankey/config/box_insertion_fix_20210529.txt'
    database = SpartanSupervisedKeypointDatabase(db_config)

    # Construct torch dataset
    config = SupervisedKeypointDatasetConfig()
    config.network_in_patch_width = 256
    config.network_in_patch_height = 256
    config.network_out_map_width = 64
    config.network_out_map_height = 64
    config.image_database_list.append(database)
    config.is_train = is_train
    dataset = SupervisedKeypointDataset(config)
    return dataset, config
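# Hypothetical entry point (a sketch, not present in the original code): build the train
# and test splits with the construct_dataset variant defined directly above and report
# their sizes as a quick sanity check.
if __name__ == '__main__':
    train_dataset, train_config = construct_dataset(is_train=True)
    test_dataset, test_config = construct_dataset(is_train=False)
    print('Training entries:', len(train_dataset))
    print('Test entries:', len(test_dataset))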