def test_pose_extractors(sim):
    """Smoke-test every registered pose extractor.

    For each supported ``pose_extractor_name``, constructing an
    ``ImageExtractor`` against the default (empty-path) scene must yield
    more than one pose, i.e. the extractor actually produced output.
    """
    scene_filepath = ""
    for extractor_name in ("closest_point_extractor", "panorama_extractor"):
        extractor = ImageExtractor(
            scene_filepath,
            img_size=(32, 32),
            sim=sim,
            pose_extractor_name=extractor_name,
        )
        assert len(extractor) > 1
def test_extractor_all_modes(sim):
    """Smoke-test every extraction method.

    Both the "closest" and "panorama" extraction methods must produce a
    non-trivial (> 1) number of samples for the default scene.
    """
    scene_filepath = ""
    for mode in ("closest", "panorama"):
        extractor = ImageExtractor(
            scene_filepath,
            img_size=(32, 32),
            sim=sim,
            extraction_method=mode,
        )
        assert len(extractor) > 1
def test_pose_extractors(make_cfg_settings):
    """Smoke-test every registered pose extractor against a fresh simulator.

    Unlike the fixture-based variant, this builds its own simulator from
    ``make_cfg_settings`` and tears it down via the context manager.
    NOTE(review): shares its name with another ``test_pose_extractors`` in
    this review chunk — presumably they live in different test modules;
    confirm before merging into one file.
    """
    with habitat_sim.Simulator(make_cfg(make_cfg_settings)) as sim:
        scene_filepath = ""
        for extractor_name in (
            "closest_point_extractor",
            "panorama_extractor",
        ):
            extractor = ImageExtractor(
                scene_filepath,
                img_size=(32, 32),
                sim=sim,
                pose_extractor_name=extractor_name,
            )
            assert len(extractor) > 1
def test_data_extractor_end_to_end(sim):
    """End-to-end smoke test: extractor -> Dataset -> DataLoader -> network.

    Wraps the extractor in ``MyDataset``, batches it, and pushes every
    batch through ``TrivialNet`` to confirm the whole pipeline runs.
    """
    # Path is relative to simulator.py
    scene_filepath = ""
    extractor = ImageExtractor(
        scene_filepath, labels=[0.0], img_size=(32, 32), sim=sim
    )
    loader = DataLoader(MyDataset(extractor), batch_size=3)
    net = TrivialNet()

    # Run data through network
    for batch in loader:
        img = batch["rgba"]
        _ = batch["label"]
        # NOTE(review): (0, 3, 2, 1) moves channels first but also swaps
        # H and W; harmless here because images are square — confirm intent.
        net(img.permute(0, 3, 2, 1).float())
}) ann_id += 1 with open(os.path.join(args.output_dir, "annotations", annotations_name), "w+") as f: json.dump(dataset, f) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--output_dir", default="./dataset") parser.add_argument("--scene", type=str) parser.add_argument("--extraction_method", default="panorama") parser.add_argument("--split", default=0.8, type=float) args = parser.parse_args() try: os.mkdir(args.output_dir) os.mkdir(os.path.join(args.output_dir, "annotations")) os.mkdir(os.path.join(args.output_dir, "images")) except BaseException as e: print(e) extractor = ImageExtractor(args.scene, labels=[0.0], img_size=(512, 512), output=["rgba", "semantic"], extraction_method="panorama") generate_set(0, int(len(extractor) * args.split), extractor, "instances_train.json") generate_set(int(len(extractor) * args.split), int(len(extractor) * (1 - args.split)), extractor, "instances_val.json")