                        default=None,
                        help='directory of the dataset to subsample')
    parser.add_argument('output_path',
                        type=str,
                        default=None,
                        help='directory to store the subsampled dataset')
    args = parser.parse_args()
    dataset_path = args.dataset_path
    output_path = args.output_path

    # open the source dataset and create the output dataset with the same config
    dataset = TensorDataset.open(dataset_path)
    out_dataset = TensorDataset(output_path, dataset.config)

    # copy datapoints in a random order
    ind = np.arange(dataset.num_datapoints)
    np.random.shuffle(ind)
    for i, j in enumerate(ind):
        logging.info('Saving datapoint %d' % (i))
        datapoint = dataset[j]
        out_dataset.add(datapoint)
    out_dataset.flush()

    # remap validation indices for each split to the shuffled order
    for split_name in dataset.split_names:
        _, val_indices, _ = dataset.split(split_name)
        new_val_indices = []
        for i in range(ind.shape[0]):
            if ind[i] in val_indices:
                new_val_indices.append(i)
        out_dataset.make_split(split_name, val_indices=new_val_indices)
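# Example invocation of the subsampling script above (the script name and paths
# are hypothetical, not taken from this repo; the two positional arguments map to
# the `dataset_path` and `output_path` values parsed by argparse above):
#
#   python subsample_dataset.py /path/to/source_dataset /path/to/subsampled_dataset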
def generate_segmask_dataset(output_dataset_path,
                             config,
                             save_tensors=True,
                             warm_start=False):
    """Generate a segmentation training dataset.

    Parameters
    ----------
    output_dataset_path : str
        path to store the dataset
    config : dict
        dictionary-like object containing parameters of the simulator and visualization
    save_tensors : bool
        save tensor datasets (for recreating state)
    warm_start : bool
        restart dataset generation from a previous state
    """
    # read subconfigs
    dataset_config = config['dataset']
    image_config = config['images']
    vis_config = config['vis']

    # debugging
    debug = config['debug']
    if debug:
        np.random.seed(SEED)

    # read general parameters
    num_states = config['num_states']
    num_images_per_state = config['num_images_per_state']
    states_per_flush = config['states_per_flush']
    states_per_garbage_collect = config['states_per_garbage_collect']

    # set max obj per state
    max_objs_per_state = config['state_space']['heap']['max_objs']

    # read image parameters
    im_height = config['state_space']['camera']['im_height']
    im_width = config['state_space']['camera']['im_width']
    segmask_channels = max_objs_per_state + 1

    # create the dataset path and all subfolders if they don't exist
    if not os.path.exists(output_dataset_path):
        os.mkdir(output_dataset_path)

    image_dir = os.path.join(output_dataset_path, 'images')
    if not os.path.exists(image_dir):
        os.mkdir(image_dir)
    color_dir = os.path.join(image_dir, 'color_ims')
    if image_config['color'] and not os.path.exists(color_dir):
        os.mkdir(color_dir)
    depth_dir = os.path.join(image_dir, 'depth_ims')
    if image_config['depth'] and not os.path.exists(depth_dir):
        os.mkdir(depth_dir)
    amodal_dir = os.path.join(image_dir, 'amodal_masks')
    if image_config['amodal'] and not os.path.exists(amodal_dir):
        os.mkdir(amodal_dir)
    modal_dir = os.path.join(image_dir, 'modal_masks')
    if image_config['modal'] and not os.path.exists(modal_dir):
        os.mkdir(modal_dir)
    semantic_dir = os.path.join(image_dir, 'semantic_masks')
    if image_config['semantic'] and not os.path.exists(semantic_dir):
        os.mkdir(semantic_dir)

    # setup logging
    experiment_log_filename = os.path.join(output_dataset_path,
                                           'dataset_generation.log')
    if os.path.exists(experiment_log_filename) and not warm_start:
        os.remove(experiment_log_filename)
    Logger.add_log_file(logger, experiment_log_filename, global_log_file=True)
    config.save(os.path.join(output_dataset_path,
                             'dataset_generation_params.yaml'))
    metadata = {}
    num_prev_states = 0

    # set dataset params
    if save_tensors:

        # read dataset subconfigs
        state_dataset_config = dataset_config['states']
        image_dataset_config = dataset_config['images']
        state_tensor_config = state_dataset_config['tensors']
        image_tensor_config = image_dataset_config['tensors']

        obj_pose_dim = POSE_DIM * max_objs_per_state
        obj_com_dim = POINT_DIM * max_objs_per_state
        state_tensor_config['fields']['obj_poses']['height'] = obj_pose_dim
        state_tensor_config['fields']['obj_coms']['height'] = obj_com_dim
        state_tensor_config['fields']['obj_ids']['height'] = max_objs_per_state

        image_tensor_config['fields']['camera_pose']['height'] = POSE_DIM

        if image_config['color']:
            image_tensor_config['fields']['color_im'] = {
                'dtype': 'uint8',
                'channels': 3,
                'height': im_height,
                'width': im_width
            }
        if image_config['depth']:
            image_tensor_config['fields']['depth_im'] = {
                'dtype': 'float32',
                'channels': 1,
                'height': im_height,
                'width': im_width
            }
        if image_config['modal']:
            image_tensor_config['fields']['modal_segmasks'] = {
                'dtype': 'uint8',
                'channels': segmask_channels,
                'height': im_height,
                'width': im_width
            }
        if image_config['amodal']:
            image_tensor_config['fields']['amodal_segmasks'] = {
                'dtype': 'uint8',
                'channels': segmask_channels,
                'height': im_height,
                'width': im_width
            }
        if image_config['semantic']:
            image_tensor_config['fields']['semantic_segmasks'] = {
                'dtype': 'uint8',
                'channels': 1,
                'height': im_height,
                'width': im_width
            }

        # create dataset filenames
        state_dataset_path = os.path.join(output_dataset_path, 'state_tensors')
        image_dataset_path = os.path.join(output_dataset_path, 'image_tensors')

        if warm_start:

            if not os.path.exists(state_dataset_path) or not os.path.exists(image_dataset_path):
                logger.error('Attempting to warm start without saved tensor dataset')
                exit(1)

            # open datasets
            logger.info('Opening state dataset')
            state_dataset = TensorDataset.open(state_dataset_path,
                                               access_mode='READ_WRITE')
            logger.info('Opening image dataset')
            image_dataset = TensorDataset.open(image_dataset_path,
                                               access_mode='READ_WRITE')

            # read configs
            state_tensor_config = state_dataset.config
            image_tensor_config = image_dataset.config

            # clean up datasets (there may be datapoints with indices
            # corresponding to non-existent data)
            num_state_datapoints = state_dataset.num_datapoints
            num_image_datapoints = image_dataset.num_datapoints
            num_prev_states = num_state_datapoints

            # clean up images
            image_ind = num_image_datapoints - 1
            image_datapoint = image_dataset[image_ind]
            while image_ind > 0 and image_datapoint['state_ind'] >= num_state_datapoints:
                image_ind -= 1
                image_datapoint = image_dataset[image_ind]

            images_to_remove = num_image_datapoints - 1 - image_ind
            logger.info('Deleting last %d image tensors' % (images_to_remove))
            if images_to_remove > 0:
                image_dataset.delete_last(images_to_remove)
            num_image_datapoints = image_dataset.num_datapoints
        else:
            # create datasets from scratch
            logger.info('Creating datasets')

            state_dataset = TensorDataset(state_dataset_path, state_tensor_config)
            image_dataset = TensorDataset(image_dataset_path, image_tensor_config)

        # read templates
        state_datapoint = state_dataset.datapoint_template
        image_datapoint = image_dataset.datapoint_template

    if warm_start:

        if not os.path.exists(os.path.join(output_dataset_path, 'metadata.json')):
            logger.error('Attempting to warm start without previously created dataset')
            exit(1)

        # Read metadata and indices
        metadata = json.load(
            open(os.path.join(output_dataset_path, 'metadata.json'), 'r'))
        test_inds = np.load(os.path.join(image_dir, 'test_indices.npy')).tolist()
        train_inds = np.load(os.path.join(image_dir, 'train_indices.npy')).tolist()

        # set obj ids and splits
        reverse_obj_ids = metadata['obj_ids']
        obj_id_map = utils.reverse_dictionary(reverse_obj_ids)
        obj_splits = metadata['obj_splits']
        obj_keys = obj_splits.keys()
        mesh_filenames = metadata['meshes']

        # Get list of images generated so far
        generated_images = sorted(os.listdir(color_dir)) if image_config['color'] \
            else sorted(os.listdir(depth_dir))
        num_total_images = len(generated_images)

        # Do our own calculation if no saved tensors
        if num_prev_states == 0:
            num_prev_states = num_total_images // num_images_per_state

        # Find images to remove and remove them from all relevant places if they exist
        num_images_to_remove = num_total_images - (num_prev_states * num_images_per_state)
        logger.info('Deleting last {} invalid images'.format(num_images_to_remove))
        for k in range(num_images_to_remove):
            im_name = generated_images[-(k + 1)]
            im_basename = os.path.splitext(im_name)[0]
            im_ind = int(im_basename.split('_')[1])
            if os.path.exists(os.path.join(depth_dir, im_name)):
                os.remove(os.path.join(depth_dir, im_name))
            if os.path.exists(os.path.join(color_dir, im_name)):
                os.remove(os.path.join(color_dir, im_name))
            if os.path.exists(os.path.join(semantic_dir, im_name)):
                os.remove(os.path.join(semantic_dir, im_name))
            if os.path.exists(os.path.join(modal_dir, im_basename)):
                shutil.rmtree(os.path.join(modal_dir, im_basename))
            if os.path.exists(os.path.join(amodal_dir, im_basename)):
                shutil.rmtree(os.path.join(amodal_dir, im_basename))
            if im_ind in train_inds:
                train_inds.remove(im_ind)
            elif im_ind in test_inds:
                test_inds.remove(im_ind)
    else:
        # Create initial env to generate metadata
        env = BinHeapEnv(config)
        obj_id_map = env.state_space.obj_id_map
        obj_keys = env.state_space.obj_keys
        obj_splits = env.state_space.obj_splits
        mesh_filenames = env.state_space.mesh_filenames
        save_obj_id_map = obj_id_map.copy()
        save_obj_id_map[ENVIRONMENT_KEY] = np.iinfo(np.uint32).max
        reverse_obj_ids = utils.reverse_dictionary(save_obj_id_map)
        metadata['obj_ids'] = reverse_obj_ids
        metadata['obj_splits'] = obj_splits
        metadata['meshes'] = mesh_filenames
        json.dump(metadata,
                  open(os.path.join(output_dataset_path, 'metadata.json'), 'w'),
                  indent=JSON_INDENT,
                  sort_keys=True)
        train_inds = []
        test_inds = []

    # generate states and images
    state_id = num_prev_states
    while state_id < num_states:

        # create env and set objects
        create_start = time.time()
        env = BinHeapEnv(config)
        env.state_space.obj_id_map = obj_id_map
        env.state_space.obj_keys = obj_keys
        env.state_space.set_splits(obj_splits)
        env.state_space.mesh_filenames = mesh_filenames
        create_stop = time.time()
        logger.info('Creating env took %.3f sec' % (create_stop - create_start))

        # sample states
        states_remaining = num_states - state_id
        for i in range(min(states_per_garbage_collect, states_remaining)):

            # log current rollout
            if state_id % config['log_rate'] == 0:
                logger.info('State: %06d' % (state_id))

            try:
                # reset env
                env.reset()
                state = env.state
                split = state.metadata['split']

                # render state
                if vis_config['state']:
                    env.view_3d_scene()

                # Save state if desired
                if save_tensors:
                    # set obj state variables
                    obj_pose_vec = np.zeros(obj_pose_dim)
                    obj_com_vec = np.zeros(obj_com_dim)
                    obj_id_vec = np.iinfo(np.uint32).max * np.ones(max_objs_per_state)
                    j = 0
                    for obj_state in state.obj_states:
                        obj_pose_vec[j * POSE_DIM:(j + 1) * POSE_DIM] = obj_state.pose.vec
                        obj_com_vec[j * POINT_DIM:(j + 1) * POINT_DIM] = obj_state.center_of_mass
                        obj_id_vec[j] = int(obj_id_map[obj_state.key])
                        j += 1

                    # store datapoint env params
                    state_datapoint['state_id'] = state_id
                    state_datapoint['obj_poses'] = obj_pose_vec
                    state_datapoint['obj_coms'] = obj_com_vec
                    state_datapoint['obj_ids'] = obj_id_vec
                    state_datapoint['split'] = split

                    # store state datapoint
                    image_start_ind = image_dataset.num_datapoints
                    image_end_ind = image_start_ind + num_images_per_state
                    state_datapoint['image_start_ind'] = image_start_ind
                    state_datapoint['image_end_ind'] = image_end_ind

                    # clean up
                    del obj_pose_vec
                    del obj_com_vec
                    del obj_id_vec

                    # add state
                    state_dataset.add(state_datapoint)

                # render images
                for k in range(num_images_per_state):

                    # reset the camera
                    if num_images_per_state > 1:
                        env.reset_camera()

                    obs = env.render_camera_image(color=image_config['color'])
                    if image_config['color']:
                        color_obs, depth_obs = obs
                    else:
                        depth_obs = obs

                    # vis obs
                    if vis_config['obs']:
                        if image_config['depth']:
                            plt.figure()
                            plt.imshow(depth_obs)
                            plt.title('Depth Observation')
                        if image_config['color']:
                            plt.figure()
                            plt.imshow(color_obs)
                            plt.title('Color Observation')
                        plt.show()

                    if image_config['modal'] or image_config['amodal'] or image_config['semantic']:

                        # render segmasks
                        amodal_segmasks, modal_segmasks = env.render_segmentation_images()

                        # retrieve segmask data
                        modal_segmask_arr = np.iinfo(np.uint8).max * np.ones(
                            [im_height, im_width, segmask_channels], dtype=np.uint8)
                        amodal_segmask_arr = np.iinfo(np.uint8).max * np.ones(
                            [im_height, im_width, segmask_channels], dtype=np.uint8)
                        stacked_segmask_arr = np.zeros(
                            [im_height, im_width, 1], dtype=np.uint8)

                        modal_segmask_arr[:, :, :env.num_objects] = modal_segmasks
                        amodal_segmask_arr[:, :, :env.num_objects] = amodal_segmasks

                        if image_config['semantic']:
                            for j in range(env.num_objects):
                                this_obj_px = np.where(modal_segmasks[:, :, j] > 0)
                                stacked_segmask_arr[this_obj_px[0], this_obj_px[1], 0] = j + 1

                    # visualize
                    if vis_config['semantic']:
                        plt.figure()
                        plt.imshow(stacked_segmask_arr.squeeze())
                        plt.show()

                    if save_tensors:
                        # save image data as tensors
                        if image_config['color']:
                            image_datapoint['color_im'] = color_obs
                        if image_config['depth']:
                            image_datapoint['depth_im'] = depth_obs[:, :, None]
                        if image_config['modal']:
                            image_datapoint['modal_segmasks'] = modal_segmask_arr
                        if image_config['amodal']:
                            image_datapoint['amodal_segmasks'] = amodal_segmask_arr
                        if image_config['semantic']:
                            image_datapoint['semantic_segmasks'] = stacked_segmask_arr

                        image_datapoint['camera_pose'] = env.camera.pose.vec
                        image_datapoint['camera_intrs'] = env.camera.intrinsics.vec
                        image_datapoint['state_ind'] = state_id
                        image_datapoint['split'] = split

                        # add image
                        image_dataset.add(image_datapoint)

                    # Save depth image and semantic masks
                    if image_config['color']:
                        ColorImage(color_obs).save(
                            os.path.join(color_dir, 'image_{:06d}.png'.format(
                                num_images_per_state * state_id + k)))
                    if image_config['depth']:
                        DepthImage(depth_obs).save(
                            os.path.join(depth_dir, 'image_{:06d}.png'.format(
                                num_images_per_state * state_id + k)))
                    if image_config['modal']:
                        modal_id_dir = os.path.join(
                            modal_dir,
                            'image_{:06d}'.format(num_images_per_state * state_id + k))
                        if not os.path.exists(modal_id_dir):
                            os.mkdir(modal_id_dir)
                        for i in range(env.num_objects):
                            BinaryImage(modal_segmask_arr[:, :, i]).save(
                                os.path.join(modal_id_dir, 'channel_{:03d}.png'.format(i)))
                    if image_config['amodal']:
                        amodal_id_dir = os.path.join(
                            amodal_dir,
                            'image_{:06d}'.format(num_images_per_state * state_id + k))
                        if not os.path.exists(amodal_id_dir):
                            os.mkdir(amodal_id_dir)
                        for i in range(env.num_objects):
                            BinaryImage(amodal_segmask_arr[:, :, i]).save(
                                os.path.join(amodal_id_dir, 'channel_{:03d}.png'.format(i)))
                    if image_config['semantic']:
                        GrayscaleImage(stacked_segmask_arr.squeeze()).save(
                            os.path.join(semantic_dir, 'image_{:06d}.png'.format(
                                num_images_per_state * state_id + k)))

                    # Save split
                    if split == TRAIN_ID:
                        train_inds.append(num_images_per_state * state_id + k)
                    else:
                        test_inds.append(num_images_per_state * state_id + k)

                # auto-flush after every so many timesteps
                if state_id % states_per_flush == 0:
                    np.save(os.path.join(image_dir, 'train_indices.npy'), train_inds)
                    np.save(os.path.join(image_dir, 'test_indices.npy'), test_inds)
                    if save_tensors:
                        state_dataset.flush()
                        image_dataset.flush()

                # delete action objects
                for obj_state in state.obj_states:
                    del obj_state
                del state
                gc.collect()

                # update state id
                state_id += 1

            except Exception as e:
                # log an error
                logger.warning('Heap failed!')
                logger.warning('%s' % (str(e)))
                logger.warning(traceback.print_exc())
                if debug:
                    raise

                # rebuild the env after a failure
                del env
                gc.collect()
                env = BinHeapEnv(config)
                env.state_space.obj_id_map = obj_id_map
                env.state_space.obj_keys = obj_keys
                env.state_space.set_splits(obj_splits)
                env.state_space.mesh_filenames = mesh_filenames

        # garbage collect
        del env
        gc.collect()

    # write all datasets to file, save indices
    np.save(os.path.join(image_dir, 'train_indices.npy'), train_inds)
    np.save(os.path.join(image_dir, 'test_indices.npy'), test_inds)
    if save_tensors:
        state_dataset.flush()
        image_dataset.flush()
    logger.info('Generated %d image datapoints' %
                (state_id * num_images_per_state))
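# A minimal usage sketch for generate_segmask_dataset, kept as comments so this
# module stays import-safe. The config filename and output path below are
# hypothetical, and YamlConfig (autolab_core) is assumed as the config loader
# (consistent with the config.save(...) call above); the YAML must contain the
# 'dataset', 'images', 'vis', 'state_space', and related keys read by the function.
#
#   from autolab_core import YamlConfig
#   config = YamlConfig('cfg/generate_mask_dataset.yaml')
#   generate_segmask_dataset('/path/to/output_dataset', config,
#                            save_tensors=True, warm_start=False)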
    def test_single_read_write(self):
        # seed
        np.random.seed(SEED)
        random.seed(SEED)

        # open dataset
        create_successful = True
        try:
            dataset = TensorDataset(TEST_TENSOR_DATASET_NAME, TENSOR_CONFIG)
        except:
            create_successful = False
        self.assertTrue(create_successful)

        # check field names
        write_datapoint = dataset.datapoint_template
        for field_name in write_datapoint.keys():
            self.assertTrue(field_name in dataset.field_names)

        # add the datapoint
        write_datapoint['float_value'] = np.random.rand()
        write_datapoint['int_value'] = int(100 * np.random.rand())
        write_datapoint['str_value'] = utils.gen_experiment_id()
        write_datapoint['vector_value'] = np.random.rand(HEIGHT)
        write_datapoint['matrix_value'] = np.random.rand(HEIGHT, WIDTH)
        write_datapoint['image_value'] = np.random.rand(HEIGHT, WIDTH, CHANNELS)
        dataset.add(write_datapoint)

        # check num datapoints
        self.assertTrue(dataset.num_datapoints == 1)

        # add metadata
        metadata_num = np.random.rand()
        dataset.add_metadata('test', metadata_num)

        # check written arrays
        dataset.flush()
        for field_name in dataset.field_names:
            filename = os.path.join(TEST_TENSOR_DATASET_NAME, 'tensors',
                                    '%s_00000.npz' % (field_name))
            value = np.load(filename)['arr_0']
            if isinstance(value[0], str):
                self.assertTrue(value[0] == write_datapoint[field_name])
            else:
                self.assertTrue(np.allclose(value[0], write_datapoint[field_name]))

        # re-open the dataset
        del dataset
        dataset = TensorDataset.open(TEST_TENSOR_DATASET_NAME)

        # read metadata
        self.assertTrue(np.allclose(dataset.metadata['test'], metadata_num))

        # read datapoint
        read_datapoint = dataset.datapoint(0)
        for field_name in dataset.field_names:
            if isinstance(read_datapoint[field_name], str):
                self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
            else:
                self.assertTrue(np.allclose(read_datapoint[field_name],
                                            write_datapoint[field_name]))

        # check iterator
        for read_datapoint in dataset:
            for field_name in dataset.field_names:
                if isinstance(read_datapoint[field_name], str):
                    self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
                else:
                    self.assertTrue(np.allclose(read_datapoint[field_name],
                                                write_datapoint[field_name]))

        # read individual fields
        for field_name in dataset.field_names:
            read_datapoint = dataset.datapoint(0, field_names=[field_name])
            if isinstance(read_datapoint[field_name], str):
                self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
            else:
                self.assertTrue(np.allclose(read_datapoint[field_name],
                                            write_datapoint[field_name]))

        # re-open the dataset in read-write mode
        del dataset
        dataset = TensorDataset.open(TEST_TENSOR_DATASET_NAME,
                                     access_mode=READ_WRITE_ACCESS)

        # delete datapoint
        dataset.delete_last()

        # check that the dataset is correct
        self.assertTrue(dataset.num_datapoints == 0)
        self.assertTrue(dataset.num_tensors == 0)
        for field_name in dataset.field_names:
            filename = os.path.join(TEST_TENSOR_DATASET_NAME, 'tensors',
                                    '%s_00000.npz' % (field_name))
            self.assertFalse(os.path.exists(filename))

        # remove dataset
        if os.path.exists(TEST_TENSOR_DATASET_NAME):
            shutil.rmtree(TEST_TENSOR_DATASET_NAME)
def run_parallel_bin_picking_benchmark(input_dataset_path,
                                       heap_ids,
                                       timesteps,
                                       output_dataset_path,
                                       config_filename):
    raise NotImplementedError('Cannot run in parallel. Need to split up the heap ids and timesteps')

    # load config
    config = YamlConfig(config_filename)

    # init ray
    ray_config = config['ray']
    num_cpus = ray_config['num_cpus']
    ray.init(num_cpus=num_cpus,
             redirect_output=ray_config['redirect_output'])

    # rollouts
    num_rollouts = config['num_rollouts'] // num_cpus
    dataset_ids = [
        rollout_bin_picking_policy_in_parallel.remote(input_dataset_path,
                                                      config_filename,
                                                      num_rollouts)
        for i in range(num_cpus)
    ]
    dataset_filenames = ray.get(dataset_ids)
    if len(dataset_filenames) == 0:
        return

    # merge datasets
    subproc_dataset = TensorDataset.open(dataset_filenames[0])
    tensor_config = subproc_dataset.config

    # open the aggregate output dataset
    dataset = TensorDataset(output_dataset_path, tensor_config)
    dataset.add_metadata('action_ids', subproc_dataset.metadata['action_ids'])

    # add datapoints
    obj_id = 0
    heap_id = 0
    obj_ids = {}
    for dataset_filename in dataset_filenames:
        logging.info('Aggregating data from %s' % (dataset_filename))
        j = 0
        subproc_dataset = TensorDataset.open(dataset_filename)
        subproc_obj_ids = subproc_dataset.metadata['obj_ids']
        for datapoint in subproc_dataset:
            if j > 0 and datapoint['timesteps'] == 0:
                heap_id += 1

            # modify object ids
            for i in range(datapoint['obj_ids'].shape[0]):
                subproc_obj_id = datapoint['obj_ids'][i]
                if subproc_obj_id != np.uint32(-1):
                    subproc_obj_key = subproc_obj_ids[str(subproc_obj_id)]
                    if subproc_obj_key not in obj_ids.keys():
                        obj_ids[subproc_obj_key] = obj_id
                        obj_id += 1
                    datapoint['obj_ids'][i] = obj_ids[subproc_obj_key]

            # modify grasped obj id
            subproc_grasped_obj_id = datapoint['grasped_obj_ids']
            grasped_obj_key = subproc_obj_ids[str(subproc_grasped_obj_id)]
            datapoint['grasped_obj_ids'] = obj_ids[grasped_obj_key]

            # modify heap id
            datapoint['heap_ids'] = heap_id

            # add datapoint to dataset
            dataset.add(datapoint)
            j += 1

    # write to disk
    obj_ids = utils.reverse_dictionary(obj_ids)
    dataset.add_metadata('obj_ids', obj_ids)
    dataset.flush()
    def test_multi_tensor_read_write(self):
        # seed
        np.random.seed(SEED)
        random.seed(SEED)

        # open dataset
        dataset = TensorDataset(TEST_TENSOR_DATASET_NAME, TENSOR_CONFIG)

        # write enough datapoints to span two tensor files
        write_datapoints = []
        for i in range(DATAPOINTS_PER_FILE + 1):
            write_datapoint = {}
            write_datapoint['float_value'] = np.random.rand()
            write_datapoint['int_value'] = int(100 * np.random.rand())
            write_datapoint['str_value'] = utils.gen_experiment_id()
            write_datapoint['vector_value'] = np.random.rand(HEIGHT)
            write_datapoint['matrix_value'] = np.random.rand(HEIGHT, WIDTH)
            write_datapoint['image_value'] = np.random.rand(HEIGHT, WIDTH, CHANNELS)
            dataset.add(write_datapoint)
            write_datapoints.append(write_datapoint)

        # check num datapoints
        self.assertTrue(dataset.num_datapoints == DATAPOINTS_PER_FILE + 1)
        self.assertTrue(dataset.num_tensors == 2)

        # check read
        dataset.flush()
        del dataset
        dataset = TensorDataset.open(TEST_TENSOR_DATASET_NAME,
                                     access_mode=READ_WRITE_ACCESS)
        for i, read_datapoint in enumerate(dataset):
            write_datapoint = write_datapoints[i]
            for field_name in dataset.field_names:
                if isinstance(read_datapoint[field_name], str):
                    self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
                else:
                    self.assertTrue(np.allclose(read_datapoint[field_name],
                                                write_datapoint[field_name]))

        # check iterator item
        for i, read_datapoint in enumerate(dataset):
            write_datapoint = write_datapoints[i]
            for field_name in dataset.field_names:
                if isinstance(read_datapoint[field_name], str):
                    self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
                else:
                    self.assertTrue(np.allclose(read_datapoint[field_name],
                                                write_datapoint[field_name]))

        # check random item
        ind = np.random.choice(dataset.num_datapoints)
        write_datapoint = write_datapoints[ind]
        read_datapoint = dataset.datapoint(ind)
        for field_name in dataset.field_names:
            if isinstance(read_datapoint[field_name], str):
                self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
            else:
                self.assertTrue(np.allclose(read_datapoint[field_name],
                                            write_datapoint[field_name]))

        # check deletion
        dataset.delete_last()
        self.assertTrue(dataset.num_datapoints == DATAPOINTS_PER_FILE)
        self.assertTrue(dataset.num_tensors == 1)
        for field_name in dataset.field_names:
            filename = os.path.join(TEST_TENSOR_DATASET_NAME, 'tensors',
                                    '%s_00001.npz' % (field_name))
            self.assertFalse(os.path.exists(filename))

        # add datapoints back, spanning a third tensor file
        dataset.add(write_datapoints[-1])
        for write_datapoint in write_datapoints:
            dataset.add(write_datapoint)
        self.assertTrue(dataset.num_datapoints == 2 * (DATAPOINTS_PER_FILE + 1))
        self.assertTrue(dataset.num_tensors == 3)

        # check valid
        for i in range(dataset.num_datapoints):
            read_datapoint = dataset.datapoint(i)
            write_datapoint = write_datapoints[i % (len(write_datapoints))]
            for field_name in dataset.field_names:
                if isinstance(read_datapoint[field_name], str):
                    self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
                else:
                    self.assertTrue(np.allclose(read_datapoint[field_name],
                                                write_datapoint[field_name]))

        # check read then write out of order
        ind = np.random.choice(DATAPOINTS_PER_FILE)
        write_datapoint = write_datapoints[ind]
        read_datapoint = dataset.datapoint(ind)
        for field_name in dataset.field_names:
            if isinstance(read_datapoint[field_name], str):
                self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
            else:
                self.assertTrue(np.allclose(read_datapoint[field_name],
                                            write_datapoint[field_name]))

        write_datapoint = write_datapoints[0]
        dataset.add(write_datapoint)
        read_datapoint = dataset.datapoint(dataset.num_datapoints - 1)
        for field_name in dataset.field_names:
            if isinstance(read_datapoint[field_name], str):
                self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
            else:
                self.assertTrue(np.allclose(read_datapoint[field_name],
                                            write_datapoint[field_name]))
        dataset.delete_last()

        # check data integrity
        for i, read_datapoint in enumerate(dataset):
            write_datapoint = write_datapoints[i % len(write_datapoints)]
            for field_name in dataset.field_names:
                if isinstance(read_datapoint[field_name], str):
                    self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
                else:
                    self.assertTrue(np.allclose(read_datapoint[field_name],
                                                write_datapoint[field_name]))

        # delete last
        dataset.delete_last(len(write_datapoints))
        self.assertTrue(dataset.num_datapoints == DATAPOINTS_PER_FILE + 1)
        self.assertTrue(dataset.num_tensors == 2)
        for i, read_datapoint in enumerate(dataset):
            write_datapoint = write_datapoints[i]
            for field_name in dataset.field_names:
                if isinstance(read_datapoint[field_name], str):
                    self.assertTrue(read_datapoint[field_name] == write_datapoint[field_name])
                else:
                    self.assertTrue(np.allclose(read_datapoint[field_name],
                                                write_datapoint[field_name]))

        # remove dataset
        if os.path.exists(TEST_TENSOR_DATASET_NAME):
            shutil.rmtree(TEST_TENSOR_DATASET_NAME)