# NOTE(review): fragment starts mid-expression — the ")" below closes a call
# begun outside this chunk.
)
# Restore trained weights and switch the model to inference mode on GPU 0.
load_checkpoint(checkpoint_path, model)
model = model.eval()
device = torch.device("cuda:0")
model = model.to(device)
# Boolean agent mask from the chopped-validation pipeline: selects the agents
# that are actually scored.
valid_mask = np.load(
    f"{DATA_DIR}/scenes/validate_chopped_100/mask.npz")["arr_0"]
dm = LocalDataManager(DATA_DIR)
rasterizer = build_rasterizer(cfg, dm)
valid_zarr = ChunkedDataset(
    dm.require("scenes/validate_chopped_100/validate.zarr")).open()
bs = 32
valid_dataset = AgentDataset(cfg, valid_zarr, rasterizer,
                             agents_mask=valid_mask)
print(len(valid_dataset))
# valid_dataset = Subset(valid_dataset, list(range(bs * 4)))
# NOTE(review): DataLoader call continues past this chunk.
valid_dataloader = DataLoader(
    valid_dataset,
    shuffle=False,
    batch_size=bs,
# NOTE(review): top-level training-script fragment; `training_cfg`, `dm`,
# `build_model` and `tqdm` are defined outside this chunk.
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = build_model(training_cfg).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Per-element loss so availability masking can be applied downstream.
criterion = nn.MSELoss(reduction="none")
print("Model Loaded. Loading training dataset...")
print("==================================TRAIN DATA==================================")
# training cfg
train_cfg = training_cfg["train_data_loader"]
# rasterizer
rasterizer = build_rasterizer(training_cfg, dm)
# dataloader
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(training_cfg, train_zarr, rasterizer)
train_dataloader = DataLoader(train_dataset,
                              shuffle=train_cfg["shuffle"],
                              batch_size=train_cfg["batch_size"],
                              num_workers=train_cfg["num_workers"])
print(train_dataset)

# ==== TRAIN LOOP
tr_it = iter(train_dataloader)
progress_bar = tqdm(range(training_cfg["train_params"]["max_num_steps"]))
losses_train = []
# NOTE(review): loop body continues past this chunk.
for _ in progress_bar:
    try:
        data = next(tr_it)
    except StopIteration:
        # Dataloader exhausted before max_num_steps: restart a fresh pass.
        tr_it = iter(train_dataloader)
        data = next(tr_it)
valid_cfg = cfg["valid_data_loader"] # valid_path = "scenes/sample.zarr" if debug else valid_cfg["key"] valid_path = valid_cfg["key"] valid_agents_mask = None if flags.validation_chopped: num_frames_to_chop = 100 th_agent_prob = cfg["raster_params"]["filter_agents_threshold"] min_frame_future = 1 num_frames_to_copy = num_frames_to_chop valid_agents_mask = load_mask_chopped(dm.require(valid_path), th_agent_prob, num_frames_to_copy, min_frame_future) print("valid_path", valid_path, "valid_agents_mask", valid_agents_mask.shape) valid_zarr = ChunkedDataset(dm.require(valid_path)).open(cached=False) valid_agent_dataset = FasterAgentDataset( cfg, valid_zarr, rasterizer, agents_mask=valid_agents_mask, min_frame_history=flags.min_frame_history, min_frame_future=flags.min_frame_future, override_sample_function_name=flags.override_sample_function_name, ) # valid_dataset = TransformDataset(valid_agent_dataset, transform) valid_dataset = valid_agent_dataset # Only use `n_valid_data` dataset for fast check. # Sample dataset from regular interval, to increase variety/coverage n_valid_data = 150 if debug else -1
# NOTE(review): fragment starts inside a parser.add_argument(...) call begun
# outside this chunk.
                    help="zarr path")
parser.add_argument("--th_agent_prob", type=float, required=True,
                    help="perception threshold on agents")
parser.add_argument("--th_yaw_degree", type=float, default=TH_YAW_DEGREE,
                    help="max absolute distance in degree")
parser.add_argument("--th_extent_ratio", type=float, default=TH_EXTENT_RATIO,
                    help="max change in area allowed")
parser.add_argument("--th_distance_av", type=float, default=TH_DISTANCE_AV,
                    help="max distance from AV in meters")
args = parser.parse_args()

# Run agent selection over every input zarr with the CLI thresholds.
for input_folder in args.input_folders:
    zarr_dataset = ChunkedDataset(path=input_folder)
    zarr_dataset.open()
    select_agents(
        zarr_dataset,
        args.th_agent_prob,
        args.th_yaw_degree,
        args.th_extent_ratio,
        args.th_distance_av,
    )
def main():
    """End-to-end Lyft motion-prediction pipeline.

    Builds the config and RegNet-based model, optionally trains (gated by the
    local `Training` flag), then runs test-set inference and writes
    ``submission.csv``. Relies on module-level helpers: `self_mkdir`,
    `set_seed`, `LyftMultiModel`, `forward`, `evaluate`, `write_pred_csv`.
    """
    cfg = {
        'save_path': "./regnetx/",
        'seed': 39,
        'data_path': "kagglepath",
        # RegNet stem / block hyper-parameters.
        'stem_type': 'simple_stem_in',
        'stem_w': 32,
        'block_type': 'res_bottleneck_block',
        'ds': [2, 5, 13, 1],
        'ws': [72, 216, 576, 1512],
        'ss': [2, 2, 2, 2],
        'bms': [1.0, 1.0, 1.0, 1.0],
        'gws': [24, 24, 24, 24],
        'se_r': 0.25,
        'model_params': {
            # BUG FIX: 'model_name' was read in the train loop but never
            # defined, raising KeyError when Training=True.
            'model_name': 'regnetx',
            'history_num_frames': 10,  # 3+20+2=25 -> 25*h*w
            'history_step_size': 1,
            'history_delta_time': 0.1,
            'future_num_frames': 50,  # 1512 -> 50*2*3+3=303 nn.linear
            'future_step_size': 1,
            'future_delta_time': 0.1,
            'opt_type': 'adam',
            'lr': 3e-4,
            'w_decay': 0,
            # BUG FIX: a stray `0` token after this entry was a syntax error.
            'reduce_type': 'stone',
            'r_factor': 0.5,
            'r_step': [200_000, 300_000, 360_000, 420_000, 480_000, 540_000],
            'weight_path': './050.pth',
        },
        'raster_params': {
            'raster_size': [224, 224],
            'pixel_size': [0.5, 0.5],
            'ego_center': [0.25, 0.5],
            'map_type': 'py_semantic',
            'satellite_map_key': 'aerial_map/aerial_map.png',
            'semantic_map_key': 'semantic_map/semantic_map.pb',
            'dataset_meta_key': 'meta.json',
            'filter_agents_threshold': 0.5
        },
        'train_data_loader': {
            'key': 'scenes/train.zarr',
            'batch_size': 16,
            'shuffle': True,
            'num_workers': 20
        },
        'val_data_loader': {
            'key': "scenes/validate.zarr",
            'batch_size': 16,
            'shuffle': False,
            'num_workers': 20
        },
        'test_data_loader': {
            'key': 'scenes/test.zarr',
            'batch_size': 32,
            'shuffle': False,
            'num_workers': 4
        },
        'train_params': {
            'max_num_steps': 600_000,
            'checkpoint_every_n_steps': 100_000,
            'eval_every_n_steps': 100_000,
        }
    }

    writer = self_mkdir(cfg)
    set_seed(cfg["seed"])
    os.environ["L5KIT_DATA_FOLDER"] = cfg["data_path"]
    dm = LocalDataManager(None)

    model = LyftMultiModel(cfg)  # 1
    weight_path = cfg["model_params"]["weight_path"]
    if weight_path:
        model.load_state_dict(torch.load(weight_path))
    else:
        print('no check points')
    model.cuda()

    # Optimizer / LR-schedule selection from the config.
    m_params = cfg["model_params"]
    if m_params['opt_type'] == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=m_params["lr"],
                              weight_decay=m_params['w_decay'])
    elif m_params['opt_type'] == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=m_params["lr"],
                               weight_decay=m_params['w_decay'])
    else:
        assert False, 'cfg opt_type error'

    if m_params['reduce_type'] == 'stone':
        lr_sche = optim.lr_scheduler.MultiStepLR(
            optimizer, m_params['r_step'], gamma=m_params['r_factor'],
            last_epoch=-1)
    else:
        assert False, 'cfg reduce_type error'

    Training = False
    if Training:
        train_cfg = cfg["train_data_loader"]
        rasterizer = build_rasterizer(cfg, dm)
        train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
        train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
        train_dataloader = DataLoader(train_dataset,
                                      shuffle=train_cfg["shuffle"],
                                      batch_size=train_cfg["batch_size"],
                                      num_workers=train_cfg["num_workers"])
        print(train_dataset)

        tr_it = iter(train_dataloader)
        progress_bar = tqdm(range(cfg["train_params"]["max_num_steps"]))
        model_name = cfg["model_params"]["model_name"]
        first_time = True
        eval_dataloader = None
        eval_gt_path = None
        model.train()
        torch.set_grad_enabled(True)
        loss_ten = 0
        for i in progress_bar:
            try:
                data = next(tr_it)
            except StopIteration:
                tr_it = iter(train_dataloader)
                data = next(tr_it)
            loss, _, _ = forward(data, model)  # 2
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_sche.step()
            writer.add_scalar('train_loss', loss.item(), i)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], i)
            if i == 10:
                # Remember an early loss value as a progress reference.
                loss_ten = loss.item()
            progress_bar.set_description(f"loss: {loss.item()} and {loss_ten}")
            if (i + 1) % cfg['train_params']['checkpoint_every_n_steps'] == 0:
                torch.save(model.state_dict(),
                           os.path.join(cfg['save_path'],
                                        f'{model_name}_{i}.pth'))
            if (i + 1) % cfg['train_params']['eval_every_n_steps'] == 0:  # 3
                first_time, eval_dataloader, eval_gt_path = evaluate(
                    cfg, model, dm, rasterizer, first_time, i + 1,
                    eval_dataloader, eval_gt_path)

    # ---- Test-set inference.
    test_cfg = cfg["test_data_loader"]
    rasterizer = build_rasterizer(cfg, dm)
    test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
    test_mask = np.load(
        os.path.join(cfg["data_path"], 'scenes/mask.npz'))["arr_0"]
    test_dataset = AgentDataset(cfg, test_zarr, rasterizer,
                                agents_mask=test_mask)
    test_dataloader = DataLoader(test_dataset,
                                 shuffle=test_cfg["shuffle"],
                                 batch_size=test_cfg["batch_size"],
                                 num_workers=test_cfg["num_workers"])
    print(test_dataset)

    model.eval()
    torch.set_grad_enabled(False)

    # store information for evaluation
    future_coords_offsets_pd = []
    timestamps = []
    confidences_list = []
    agent_ids = []

    progress_bar = tqdm(test_dataloader)
    for data in progress_bar:
        _, preds, confidences = forward(data, model)

        # fix for the new environment: convert predictions into world
        # coordinates, then back into offsets from the agent centroid.
        preds = preds.cpu().numpy()
        world_from_agents = data["world_from_agent"].numpy()
        centroids = data["centroid"].numpy()
        for idx in range(len(preds)):
            for mode in range(3):
                preds[idx, mode, :, :] = transform_points(
                    preds[idx, mode, :, :],
                    world_from_agents[idx]) - centroids[idx][:2]

        future_coords_offsets_pd.append(preds.copy())
        confidences_list.append(confidences.cpu().numpy().copy())
        timestamps.append(data["timestamp"].numpy().copy())
        agent_ids.append(data["track_id"].numpy().copy())

    pred_path = 'submission.csv'
    write_pred_csv(pred_path,
                   timestamps=np.concatenate(timestamps),
                   track_ids=np.concatenate(agent_ids),
                   coords=np.concatenate(future_coords_offsets_pd),
                   confs=np.concatenate(confidences_list))
# NOTE(review): fragment starts with the closing brace of a dict literal
# begun outside this chunk.
}

print("Load dataset...")

# Fall back to a sensible default test-loader config when cfg lacks one.
default_test_cfg = {
    'key': 'scenes/test.zarr',
    'batch_size': 32,
    'shuffle': False,
    'num_workers': 4,
}
test_cfg = cfg.get("test_data_loader", default_test_cfg)

rasterizer = build_rasterizer(cfg, dm)
test_path = test_cfg["key"]
print(f"Loading from {test_path}")
test_zarr = ChunkedDataset(dm.require(test_path)).open()
print("test_zarr", type(test_zarr))
# Competition-provided mask restricting inference to the scored agents.
test_mask = np.load(f"{DIR_INPUT}/scenes/mask.npz")["arr_0"]
test_agent_dataset = AgentDataset(cfg, test_zarr, rasterizer,
                                  agents_mask=test_mask)
test_dataset = test_agent_dataset
test_loader = DataLoader(
    test_dataset,
    shuffle=test_cfg["shuffle"],
    batch_size=test_cfg["batch_size"],
    num_workers=test_cfg["num_workers"],
    pin_memory=True,
)
def train(model, device, data_path, lr=1e-3, force_iters=None,
          file_name="resnet.pth"):
    """Train `model` on the L5Kit agent dataset and save its weights.

    Args:
        model: module exposing `.cfg` and `.forward(data, criterion)` that
            returns `(loss, ..., ...)`.
        device: unused here; kept for interface compatibility (the model is
            assumed to handle device placement itself — TODO confirm).
        data_path: root of the L5Kit data folder (sets L5KIT_DATA_FOLDER).
        lr: Adam learning rate.
        force_iters: when given, overrides cfg's max_num_steps.
        file_name: weights filename under ./model/.

    Returns:
        The trained model.
    """
    # set env variable for data
    os.environ["L5KIT_DATA_FOLDER"] = data_path
    dm = LocalDataManager(None)
    cfg = model.cfg
    print(cfg)

    # ===== INIT DATASET
    train_cfg = cfg["train_data_loader"]
    rasterizer = build_rasterizer(cfg, dm)
    train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
    train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=train_cfg["shuffle"],
                                  batch_size=train_cfg["batch_size"],
                                  num_workers=train_cfg["num_workers"])
    print(train_dataset)

    # ==== INIT MODEL parameters
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss(reduction="none")

    # ==== TRAIN LOOP
    if force_iters is None:
        iterations = cfg["train_params"]["max_num_steps"]
    else:
        iterations = force_iters

    tr_it = iter(train_dataloader)
    progress_bar = tqdm(range(iterations))
    losses_train = []
    rolling_avg = []
    # PERF FIX: keep a running sum so the mean is O(1) per step instead of
    # recomputing np.mean over the whole history (accidental O(n^2)).
    loss_sum = 0.0
    for i in progress_bar:
        try:
            data = next(tr_it)
        except StopIteration:
            # Restart the dataloader when it is exhausted mid-run.
            tr_it = iter(train_dataloader)
            data = next(tr_it)
        model.train()
        torch.set_grad_enabled(True)

        loss, _, _ = model.forward(data, criterion)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        losses_train.append(loss_val)
        loss_sum += loss_val
        mean_loss = loss_sum / len(losses_train)
        rolling_avg.append(mean_loss)
        progress_bar.set_description(
            f"loss: {loss_val} loss(avg): {mean_loss}")

    print("Done Training")
    torch.save(model.state_dict(), f"{os.getcwd()}/model/{file_name}")
    plt.plot(rolling_avg)
    return model
# The result is that **each scene has been reduced to only 100 frames**, and **only valid agents in the 100th frame will be used to compute the metrics**. Because following frames in the scene have been chopped off, we can't just look ahead to get the future of those agents. # # In this example, we simulate this pipeline by running `chop_dataset` on the validation set. The function stores: # - a new chopped `.zarr` dataset, in which each scene has only the first 100 frames; # - a numpy mask array where only valid agents in the 100th frame are True; # - a ground-truth file with the future coordinates of those agents; # In[47]: eval_cfg = cfg["val_data_loader"] eval_zarr_path = str( Path(eval_base_path) / Path(dm.require(eval_cfg["key"])).name) eval_mask_path = str(Path(eval_base_path) / "mask.npz") eval_gt_path = str(Path(eval_base_path) / "gt.csv") eval_zarr = ChunkedDataset(eval_zarr_path).open() eval_mask = np.load(eval_mask_path)["arr_0"] # ===== INIT DATASET AND LOAD MASK eval_dataset = AgentDataset(cfg, eval_zarr, rasterizer, agents_mask=eval_mask) eval_dataloader = DataLoader(eval_dataset, shuffle=eval_cfg["shuffle"], batch_size=eval_cfg["batch_size"], num_workers=eval_cfg["num_workers"]) print(eval_dataset) def model_validation_score(model, pred_path): # ==== EVAL LOOP model.eval() torch.set_grad_enabled(False)
from l5kit.rasterization import build_rasterizer
from l5kit.configs import load_config_data
from l5kit.data import ChunkedDataset
from l5kit.dataset import AgentDataset
import numpy as np

# Declare global variables
zarr_dt = ChunkedDataset("/home/majoradi/Documents/l5_sample/sample.zarr")
zarr_dt.open()
AGENTS = zarr_dt.agents
FRAMES = zarr_dt.frames
SCENES = zarr_dt.scenes
# Only the first `datasetSize` agent records are dumped.
datasetSize = 5
cols = ["frame_id", "object_id", "object_type", "posx", "posy", "posz",
        "velx", "vely", "length", "width", "height", "heading"]
dataset_dict = {}
dataset_list = []
# Initialise every output column to 0.
for col in cols:
    dataset_dict[col] = 0


# ADDING AGENT INFORMATION
def addAgentInformation():
    # NOTE(review): function continues past this chunk; it unpacks the
    # structured agent record by positional field index and writes into the
    # shared `dataset_dict` (each iteration overwrites the previous one).
    for idx in range(datasetSize):
        dataset_dict["posx"], dataset_dict["posy"] = AGENTS[idx][0]
        dataset_dict["length"], dataset_dict["width"], dataset_dict["height"] = AGENTS[idx][1]
        dataset_dict["heading"] = AGENTS[idx][2]
        dataset_dict["velx"], dataset_dict["vely"] = AGENTS[idx][3]
        dataset_dict["object_id"] = AGENTS[idx][4]  # Not sure
        label_probabilities = np.array(AGENTS[idx][5])
        obj_type = np.where(label_probabilities == 1)[0] + 1  # Not sure
def create_chopped_dataset(
    zarr_path: str, th_agent_prob: float, num_frames_to_copy: int,
    num_frames_gt: int, min_frame_future: int
) -> str:
    """Create a chopped version of the zarr that can be used as a test set.

    This function was used to generate the test set for the competition so
    that the future GT is not in the data.

    Store:
     - a dataset where each scene has been chopped at `num_frames_to_copy` frames;
     - a mask for agents for those final frames based on the original mask and
       a threshold on the future_frames;
     - the GT csv for those agents

    For the competition, only the first two (dataset and mask) will be
    available in the notebooks.

    Args:
        zarr_path (str): input zarr path to be chopped
        th_agent_prob (float): threshold over agents probabilities used in
            select_agents function
        num_frames_to_copy (int): number of frames to copy from the beginning
            of each scene, others will be discarded
        num_frames_gt (int): number of future predictions to store in the GT
            file
        min_frame_future (int): minimum number of frames that must be
            available in the future for an agent

    Returns:
        str: the parent folder of the new data
    """
    zarr_path = Path(zarr_path)
    dest_path = zarr_path.parent / f"{zarr_path.stem}_chopped_{num_frames_to_copy}"
    chopped_path = dest_path / zarr_path.name
    gt_path = dest_path / "gt.csv"
    mask_chopped_path = dest_path / "mask"

    # Create standard mask for the dataset so we can use it to filter out
    # unreliable agents.
    zarr_dt = ChunkedDataset(str(zarr_path))
    zarr_dt.open()

    agents_mask_path = Path(zarr_path) / f"agents_mask/{th_agent_prob}"
    if not agents_mask_path.exists():  # don't check in root but check for the path
        select_agents(
            zarr_dt,
            th_agent_prob=th_agent_prob,
            th_yaw_degree=TH_YAW_DEGREE,
            th_extent_ratio=TH_EXTENT_RATIO,
            th_distance_av=TH_DISTANCE_AV,
        )
    agents_mask_origin = np.asarray(convenience.load(str(agents_mask_path)))

    # create chopped dataset
    zarr_scenes_chop(str(zarr_path), str(chopped_path),
                     num_frames_to_copy=num_frames_to_copy)
    zarr_chopped = ChunkedDataset(str(chopped_path))
    zarr_chopped.open()

    # Compute the chopped boolean mask, but also the original one limited to
    # frames of interest for the GT csv.
    # BUG FIX: `np.bool` (deprecated NumPy 1.20, removed 1.24) -> builtin bool.
    agents_mask_chop_bool = np.zeros(len(zarr_chopped.agents), dtype=bool)
    agents_mask_orig_bool = np.zeros(len(zarr_dt.agents), dtype=bool)

    for idx in range(len(zarr_dt.scenes)):
        scene = zarr_dt.scenes[idx]

        # Last kept frame of the original scene and its agents slice,
        # and the matching final frame of the chopped scene.
        frame_original = zarr_dt.frames[
            scene["frame_index_interval"][0] + num_frames_to_copy - 1]
        slice_agents_original = get_agents_slice_from_frames(frame_original)
        frame_chopped = zarr_chopped.frames[
            zarr_chopped.scenes[idx]["frame_index_interval"][-1] - 1]
        slice_agents_chopped = get_agents_slice_from_frames(frame_chopped)

        # Column 1 is compared to min_frame_future (presumably the number of
        # available future frames — matches the select_agents mask layout).
        mask = agents_mask_origin[slice_agents_original][:, 1] >= min_frame_future
        agents_mask_orig_bool[slice_agents_original] = mask.copy()
        agents_mask_chop_bool[slice_agents_chopped] = mask.copy()

    # store the mask and the GT csv of frames of interest
    np.savez(str(mask_chopped_path), agents_mask_chop_bool)
    export_zarr_to_csv(zarr_dt, str(gt_path), num_frames_gt, th_agent_prob,
                       agents_mask=agents_mask_orig_bool)
    return str(dest_path)
def test_zarr_concat(dmg: LocalDataManager, tmp_path: Path,
                     zarr_dataset: ChunkedDataset) -> None:
    """Concatenate the same zarr `concat_count` times and verify array
    lengths, index-interval shifting, and element equality in the result."""
    concat_count = 4
    zarr_input_path = dmg.require("single_scene.zarr")
    zarr_output_path = str(tmp_path / f"{uuid4()}.zarr")

    zarr_concat([zarr_input_path] * concat_count, zarr_output_path)
    zarr_cat_dataset = ChunkedDataset(zarr_output_path)
    zarr_cat_dataset.open()

    # check lens of arrays
    assert len(zarr_cat_dataset.scenes) == len(zarr_dataset.scenes) * concat_count
    assert len(zarr_cat_dataset.frames) == len(zarr_dataset.frames) * concat_count
    assert len(zarr_cat_dataset.agents) == len(zarr_dataset.agents) * concat_count
    assert len(zarr_cat_dataset.tl_faces) == len(zarr_dataset.tl_faces) * concat_count

    # check the first and last element concat_count times
    # TODO refactor to test all elements
    input_scene_a = zarr_dataset.scenes[0]
    input_scene_b = zarr_dataset.scenes[-1]
    input_frame_a = zarr_dataset.frames[0]
    input_frame_b = zarr_dataset.frames[-1]
    input_agent_a = zarr_dataset.agents[0]
    input_agent_b = zarr_dataset.agents[-1]
    input_tl_a = zarr_dataset.tl_faces[0]
    input_tl_b = zarr_dataset.tl_faces[-1]

    for idx in range(concat_count):
        # First and last scene of the idx-th concatenated copy.
        output_scene_a = zarr_cat_dataset.scenes[idx * len(zarr_dataset.scenes)]
        output_scene_b = zarr_cat_dataset.scenes[(idx + 1) * len(zarr_dataset.scenes) - 1]

        # check all scene fields
        assert output_scene_a["host"] == input_scene_a["host"]
        assert output_scene_a["start_time"] == input_scene_a["start_time"]
        assert output_scene_a["end_time"] == input_scene_a["end_time"]
        # Frame intervals must be shifted by the frames already copied.
        displace_frame = len(zarr_dataset.frames) * idx
        assert np.all(output_scene_a["frame_index_interval"] == input_scene_a["frame_index_interval"] + displace_frame)
        assert np.all(output_scene_b["frame_index_interval"] == input_scene_b["frame_index_interval"] + displace_frame)

        # check all the frame fields
        output_frame_a = zarr_cat_dataset.frames[idx * len(zarr_dataset.frames)]
        output_frame_b = zarr_cat_dataset.frames[(idx + 1) * len(zarr_dataset.frames) - 1]
        assert np.allclose(output_frame_a["ego_rotation"], input_frame_a["ego_rotation"])
        assert np.allclose(output_frame_b["ego_rotation"], input_frame_b["ego_rotation"])
        assert np.allclose(output_frame_a["ego_translation"], input_frame_a["ego_translation"])
        assert np.allclose(output_frame_b["ego_translation"], input_frame_b["ego_translation"])
        assert output_frame_a["timestamp"] == input_frame_a["timestamp"]
        assert output_frame_b["timestamp"] == input_frame_b["timestamp"]
        # Agent and traffic-light intervals shift analogously.
        displace_agent = len(zarr_dataset.agents) * idx
        assert np.all(output_frame_a["agent_index_interval"] == input_frame_a["agent_index_interval"] + displace_agent)
        assert np.all(output_frame_b["agent_index_interval"] == input_frame_b["agent_index_interval"] + displace_agent)
        displace_tl = len(zarr_dataset.tl_faces) * idx
        assert np.all(output_frame_a["traffic_light_faces_index_interval"] == input_frame_a["traffic_light_faces_index_interval"] + displace_tl)
        assert np.all(output_frame_b["traffic_light_faces_index_interval"] == input_frame_b["traffic_light_faces_index_interval"] + displace_tl)

        # check agents
        output_agent_a = zarr_cat_dataset.agents[idx * len(zarr_dataset.agents)]
        output_agent_b = zarr_cat_dataset.agents[(idx + 1) * len(zarr_dataset.agents) - 1]
        assert output_agent_a == input_agent_a
        assert output_agent_b == input_agent_b

        # check tfl
        output_tl_a = zarr_cat_dataset.tl_faces[idx * len(zarr_dataset.tl_faces)]
        output_tl_b = zarr_cat_dataset.tl_faces[(idx + 1) * len(zarr_dataset.tl_faces) - 1]
        assert output_tl_a == input_tl_a
        assert output_tl_b == input_tl_b
        # NOTE(review): this `return` is the tail of a __len__ whose `def`
        # line is outside this chunk.
        return len(self.get_frame_arguments)

    def __getitem__(self, index: int) -> dict:
        """Return the sample dict for the agent at `index`, built via the
        custom frame-extraction helper rather than EgoDataset.get_frame."""
        track_id, scene_index, state_index = self.get_frame_arguments[index]
        # return self.ego_dataset.get_frame(scene_index, state_index, track_id=track_id)
        return get_frame_custom(self.ego_dataset, scene_index, state_index,
                                track_id=track_id)


if __name__ == "__main__":
    from l5kit.data import LocalDataManager
    from l5kit.rasterization import build_rasterizer
    from lib.utils.yaml_utils import load_yaml

    repo_root = Path(__file__).parent.parent.parent.parent
    dm = LocalDataManager(local_data_folder=str(
        repo_root / "input" / "lyft-motion-prediction-autonomous-vehicles"))
    dataset = ChunkedDataset(dm.require("scenes/sample.zarr")).open(cached=False)
    cfg = load_yaml(repo_root / "src" / "modeling" / "configs" / "0905_cfg.yaml")
    rasterizer = build_rasterizer(cfg, dm)
    faster_agent_dataset = FasterAgentDataset(cfg, dataset, rasterizer, None)
    fast_agent_dataset = FastAgentDataset(cfg, dataset, rasterizer, None)
    # Sanity check: the two implementations must agree exactly on every
    # compared field for (up to) the first 1000 samples.
    assert len(faster_agent_dataset) == len(fast_agent_dataset)
    keys = ["image", "target_positions", "target_availabilities"]
    for index in tqdm(range(min(1000, len(faster_agent_dataset)))):
        actual = faster_agent_dataset[index]
        expected = fast_agent_dataset[index]
        for key in keys:
            assert (actual[key] == expected[key]).all()
args = parser.parse_args()

# Distinct colour cycles for map elements vs. target agents.
cycled_colors = plt.get_cmap("tab20b")(np.linspace(0, 1, 20))
cycled_colors_targets = plt.get_cmap("Set2")(np.linspace(0, 1, 8))

intersection_i = args.intersection_i
dataset_basename = args.dataset_basename
vis_inputs = args.vis_inputs
vis_predictions = args.vis_predictions
start_scene = args.start_scene
end_scene = args.end_scene
scenes_per_video = args.scenes_per_video

# Pre-filtered zarr produced by the mask-filtering pipeline.
dataset_path = f"input/scenes/{dataset_basename}_filtered_min_frame_history_4_min_frame_future_1_with_mask_idx.zarr"
zarr_dataset_filtered = ChunkedDataset(dataset_path)
zarr_dataset_filtered.open()

cfg = load_config_data("input/visualisation_config.yaml")
cfg["raster_params"]["map_type"] = "py_semantic"
dm = LocalDataManager()
rast = build_rasterizer(cfg, dm)
dataset_filtered = EgoDataset(cfg, zarr_dataset_filtered, rast)
frame_dataset = FramesDataset(dataset_path)


# NOTE(review): signature and body continue past this chunk.
def plot_line(
    ax,
    line_id,
    speed=None,
    completion=None,
print("Load dataset...")
train_cfg = cfg["train_data_loader"]
valid_cfg = cfg["valid_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)


# Train dataset/dataloader
def transform(batch):
    """Unpack an AgentDataset sample dict into (image, targets, availabilities)."""
    return batch["image"], batch["target_positions"], batch[
        "target_availabilities"]


train_path = "scenes/sample.zarr" if debug else train_cfg["key"]
train_zarr = ChunkedDataset(dm.require(train_path)).open()
print("train_zarr", type(train_zarr))
train_agent_dataset = AgentDataset(cfg, train_zarr, rasterizer)
train_dataset = TransformDataset(train_agent_dataset, transform)
if debug:
    # Only use 1000 dataset for fast check...
    train_dataset = Subset(train_dataset, np.arange(1000))
train_loader = DataLoader(train_dataset,
                          shuffle=train_cfg["shuffle"],
                          batch_size=train_cfg["batch_size"],
                          num_workers=train_cfg["num_workers"])
print(train_agent_dataset)

valid_path = "scenes/sample.zarr" if debug else valid_cfg["key"]
valid_zarr = ChunkedDataset(dm.require(valid_path)).open()
# BUG FIX: this originally printed type(train_zarr) (copy-paste error);
# report the validation zarr's type instead.
print("valid_zarr", type(valid_zarr))
def zarr_dataset(dmg: LocalDataManager) -> ChunkedDataset:
    """Resolve the single-scene test zarr via the data manager, open it, and
    return the ready-to-use dataset."""
    dataset = ChunkedDataset(path=dmg.require("single_scene.zarr"))
    dataset.open()
    return dataset
def evaluate(model, device, data_path):
    """Run test-set inference, write ``./submission.csv``, then score it
    against a freshly chopped copy of the test set.

    Args:
        model: trained model exposing `.cfg` and
            `.forward(data, device, criterion)` returning `(loss, preds, _)`.
        device: torch device, passed through to the model's forward.
        data_path: root of the L5Kit data folder (sets L5KIT_DATA_FOLDER).
    """
    # set env variable for data
    os.environ["L5KIT_DATA_FOLDER"] = data_path
    dm = LocalDataManager(None)
    cfg = model.cfg

    # ===== INIT DATASET
    test_cfg = cfg["test_data_loader"]
    rasterizer = build_rasterizer(cfg, dm)
    test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
    test_mask = np.load(f"{data_path}/scenes/mask.npz")["arr_0"]
    test_dataset = AgentDataset(cfg, test_zarr, rasterizer,
                                agents_mask=test_mask)
    # FIX: removed a redundant `test_dataloader = test_dataloader`
    # self-assignment that followed this statement.
    test_dataloader = DataLoader(test_dataset,
                                 shuffle=test_cfg["shuffle"],
                                 batch_size=test_cfg["batch_size"],
                                 num_workers=test_cfg["num_workers"])
    print(test_dataloader)

    # ==== EVAL LOOP
    model.eval()
    torch.set_grad_enabled(False)
    criterion = nn.MSELoss(reduction="none")

    # store information for evaluation
    future_coords_offsets_pd = []
    timestamps = []
    agent_ids = []

    progress_bar = tqdm(test_dataloader)
    for data in progress_bar:
        _, outputs, _ = model.forward(data, device, criterion)
        future_coords_offsets_pd.append(outputs.cpu().numpy().copy())
        timestamps.append(data["timestamp"].numpy().copy())
        agent_ids.append(data["track_id"].numpy().copy())

    # ==== Save Results
    pred_path = "./submission.csv"
    write_pred_csv(pred_path,
                   timestamps=np.concatenate(timestamps),
                   track_ids=np.concatenate(agent_ids),
                   coords=np.concatenate(future_coords_offsets_pd))

    # ===== GENERATE AND LOAD CHOPPED DATASET
    num_frames_to_chop = 56
    test_cfg = cfg["test_data_loader"]
    test_base_path = create_chopped_dataset(
        zarr_path=dm.require(test_cfg["key"]),
        th_agent_prob=cfg["raster_params"]["filter_agents_threshold"],
        num_frames_to_copy=num_frames_to_chop,
        num_frames_gt=cfg["model_params"]["future_num_frames"],
        min_frame_future=MIN_FUTURE_STEPS)
    eval_zarr_path = str(
        Path(test_base_path) / Path(dm.require(test_cfg["key"])).name)
    print(eval_zarr_path)
    test_mask_path = str(Path(test_base_path) / "mask.npz")
    test_gt_path = str(Path(test_base_path) / "gt.csv")
    test_zarr = ChunkedDataset(eval_zarr_path).open()
    test_mask = np.load(test_mask_path)["arr_0"]

    # ===== INIT DATASET AND LOAD MASK
    test_dataset = AgentDataset(cfg, test_zarr, rasterizer,
                                agents_mask=test_mask)
    test_dataloader = DataLoader(test_dataset,
                                 shuffle=test_cfg["shuffle"],
                                 batch_size=test_cfg["batch_size"],
                                 num_workers=test_cfg["num_workers"])
    print(test_dataset)

    # ==== Perform Evaluation
    print(test_gt_path)
    metrics = compute_metrics_csv(test_gt_path, pred_path,
                                  [neg_multi_log_likelihood, time_displace])
    for metric_name, metric_mean in metrics.items():
        print(metric_name, metric_mean)
# NOTE(review): fragment starts inside a cfg dict literal begun outside this
# chunk.
    },
    'train_params': {
        'max_num_steps': 101,
        'checkpoint_every_n_steps': 20,
    }
}

# set env variable for data
DIR_INPUT = cfg["data_path"]
os.environ["L5KIT_DATA_FOLDER"] = DIR_INPUT
dm = LocalDataManager(None)

# ===== INIT TRAIN DATASET============================================================
train_cfg = cfg["train_data_loader"]
rasterizer = build_rasterizer(cfg, dm)
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
train_dataloader = DataLoader(train_dataset,
                              shuffle=train_cfg["shuffle"],
                              batch_size=train_cfg["batch_size"],
                              num_workers=train_cfg["num_workers"])
print(
    "==================================TRAIN DATA=================================="
)
print(train_dataset)

#====== INIT TEST DATASET=============================================================
test_cfg = cfg["test_data_loader"]
rasterizer = build_rasterizer(cfg, dm)
test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
# Competition mask restricting inference to the scored agents.
test_mask = np.load(f"{DIR_INPUT}/scenes/mask.npz")["arr_0"]
    def __init__(
        self,
        dset_name=None,
        cfg_path="./agent_motion_config.yaml",
        cfg_data=None,
        stage=None,
    ):
        """Set up a dataset backed by pre-rendered raster .npz files.

        Args:
            dset_name: split name (one of LyftDataset's DSET_* constants).
            cfg_path: YAML config path, used when `cfg_data` is None.
            cfg_data: already-loaded config dict; takes precedence.
            stage: deprecated alias for `dset_name`.
        """
        # Resolve the deprecated `stage=` alias into `dset_name`.
        if stage is not None:
            print(
                'LyftDatasetPrerendered:: argument "stage=" is deprecated, use "dset_name=" instead'
            )
            if dset_name is None:
                dset_name = stage
            else:
                raise ValueError(
                    'LyftDatasetPrerendered::Please use only "dset_name" argument'
                )
        assert dset_name is not None
        logger.info(f"Initializing prerendered {dset_name} dataset...")
        self.dm = LocalDataManager(None)
        self.dset_name = dset_name
        if cfg_data is None:
            self.cfg = load_config_data(cfg_path)
        else:
            self.cfg = cfg_data
        # only used for rgb visualisation
        self.rasterizer = build_custom_rasterizer(self.cfg, self.dm)
        self.dset_cfg = self.cfg[
            LyftDataset.name_2_dataloader_key[dset_name]].copy()
        root_dir = self.dset_cfg.get("root_dir", None)
        if root_dir is None:
            # Default pre-render cache location depends on the split name.
            data_dir_name = {
                LyftDataset.DSET_TRAIN: "train_uncompressed",
                LyftDataset.DSET_TRAIN_XXL: "train_XXL",
                LyftDataset.DSET_VALIDATION: "validate_uncompressed",
                LyftDataset.DSET_TEST: "test",
            }[dset_name]
            self.root_dir_name = join(
                config.L5KIT_DATA_FOLDER,
                self.cfg["raster_params"]["pre_render_cache_dir"],
                data_dir_name)
        else:
            self.root_dir_name = join(config.L5KIT_DATA_FOLDER, root_dir)
        print("load pre-rendered raster from", self.root_dir_name)
        self.segmentation_output = self.cfg["raster_params"].get(
            "segmentation_output", None)
        self.segmentation_results_dir = self.cfg["raster_params"].get(
            "segmentation_results_dir", None)
        self.add_own_agent_mask = self.cfg["raster_params"].get(
            "add_own_agent_mask", False)
        print(
            f"Segmentation model res: {self.segmentation_results_dir} {self.add_own_agent_mask}"
        )
        # Cache the (expensive) recursive npz listing next to the data; fall
        # back to regenerating it when the cache file is absent.
        all_files_fn = join(
            self.root_dir_name,
            self.dset_cfg.get("filepaths_cache", "all_files") + ".npy")
        try:
            logger.info(f"Loading cached filenames from {all_files_fn}")
            self.all_files = np.load(all_files_fn, allow_pickle=True)
        except FileNotFoundError:
            logger.info(f"Generating and caching filenames in {all_files_fn}")
            self.all_files = list(
                sorted(
                    glob.glob(f"{self.root_dir_name}/**/*.npz",
                              recursive=True)))
            print(f"Generated all npz paths and saved to {all_files_fn}")
            np.save(all_files_fn, self.all_files)
        print(f"Found {len(self.all_files)} agents")
        self.add_agent_state = self.cfg["model_params"]["add_agent_state"]
        self.add_agent_state_history = self.cfg["model_params"].get(
            "add_agent_state_history", False)
        self.agent_state_history_steps = self.cfg["model_params"].get(
            "agent_state_history_steps", 20)
        self.max_agent_in_state_history = self.cfg["model_params"].get(
            "max_agent_in_state_history", 16)
        self.w, self.h = self.cfg["raster_params"]["raster_size"]
        self.tf_face_colors = {}
        zarr_path = self.dm.require(self.dset_cfg["key"])
        print(f"Opening Chunked Dataset {zarr_path}...")
        # print("Creating Agent Dataset...")
        # self.agent_dataset = AgentDataset(
        #     self.cfg,
        #     self.zarr_dataset,
        #     self.rasterizer,
        #     min_frame_history=0,
        #     min_frame_future=10,
        # )
        if self.add_agent_state_history:
            # Only open the raw zarr when history features are required.
            self.zarr_dataset = ChunkedDataset(zarr_path).open()
            self.all_scenes = self.zarr_dataset.scenes[:].copy()
            self.all_frames_agent_interval = self.zarr_dataset.frames[
                'agent_index_interval'].copy()
        print("Creating Agent Dataset... [OK]")
preds, confidences = model(inputs) # skip compute loss if we are doing prediction loss = criterion(targets, preds, confidences, target_availabilities) if compute_loss else 0 return loss, preds, confidences if __name__ == '__main__': # 加载数据集,准备device DIR_INPUT = cfg["data_path"] os.environ["L5KIT_DATA_FOLDER"] = DIR_INPUT dm = LocalDataManager() rasterizer = build_rasterizer(cfg, dm) train_cfg = cfg["train_data_loader"] train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open( cached=False) # to prevent run out of memory train_dataset = AgentDataset(cfg, train_zarr, rasterizer) train_dataloader = DataLoader(train_dataset, shuffle=train_cfg["shuffle"], batch_size=train_cfg["batch_size"], num_workers=train_cfg["num_workers"]) print(train_dataset) rasterizer = build_rasterizer(cfg, dm) valid_cfg = cfg["valid_data_loader"] valid_zarr = ChunkedDataset(dm.require(valid_cfg["key"])).open( cached=False) # to prevent run out of memory valid_dataset = AgentDataset(cfg, valid_zarr, rasterizer) valid_dataloader = DataLoader(valid_dataset, shuffle=valid_cfg["shuffle"], batch_size=valid_cfg["batch_size"],
def __init__(
    self,
    dset_name=None,
    cfg_path="./agent_motion_config.yaml",
    cfg_data=None,
    stage=None,
):
    """Build the zarr-backed agent dataset for the requested split.

    Args:
        dset_name: one of the ``LyftDataset.DSET_*`` split names.
        cfg_path: yaml config path, used only when ``cfg_data`` is None.
        cfg_data: pre-loaded config dict (takes precedence over ``cfg_path``).
        stage: deprecated alias for ``dset_name``.

    Raises:
        ValueError: if both ``stage`` and ``dset_name`` are given.
    """
    print(f"Initializing LyftDataset {dset_name}...")
    if stage is not None:
        # Fixed typo in warning text: "DDEPRECATION" -> "DEPRECATION".
        print(
            'DEPRECATION WARNING! LyftDataset:: argument "stage=" is deprecated, use "dset_name=" instead'
        )
        if dset_name is None:
            dset_name = stage
        else:
            raise ValueError(
                'LyftDataset::Please use only "dset_name" argument')
    assert dset_name is not None
    self.dm = LocalDataManager(None)
    self.dset_name = dset_name
    if cfg_data is None:
        self.cfg = utils.DotDict(load_config_data(cfg_path))
    else:
        self.cfg = utils.DotDict(cfg_data)
    self.dset_cfg = self.cfg[
        LyftDataset.name_2_dataloader_key[dset_name]].copy()

    if self.cfg["raster_params"]["map_type"] == "py_satellite":
        print("WARNING! USING SLOW RASTERIZER!!! py_satellite")
        self.rasterizer = build_rasterizer(self.cfg, self.dm)
    # NOTE(review): this unconditionally replaces the rasterizer built in the
    # py_satellite branch above; the custom rasterizer is the one actually used.
    self.rasterizer = build_custom_rasterizer(self.cfg, self.dm)

    if dset_name == LyftDataset.DSET_VALIDATION_CHOPPED:
        # Pre-chopped 100-frame validation set with its agents mask and gt csv.
        eval_base_path = Path(
            "/opt/data3/lyft_motion_prediction/prediction_dataset/scenes/validate_chopped_100"
        )
        eval_zarr_path = str(
            Path(eval_base_path) /
            Path(self.dm.require(self.dset_cfg["key"])).name)
        eval_mask_path = str(Path(eval_base_path) / "mask.npz")
        self.eval_gt_path = str(Path(eval_base_path) / "gt.csv")
        self.zarr_dataset = ChunkedDataset(eval_zarr_path).open(
            cached=False)
        self.agent_dataset = AgentDataset(
            self.cfg,
            self.zarr_dataset,
            self.rasterizer,
            agents_mask=np.load(eval_mask_path)["arr_0"],
        )
        # Index the ground truth by track_id + timestamp for O(1) lookup.
        self.val_chopped_gt = defaultdict(dict)
        for el in read_gt_csv(self.eval_gt_path):
            self.val_chopped_gt[el["track_id"] + el["timestamp"]] = el
    elif dset_name == LyftDataset.DSET_TEST:
        self.zarr_dataset = ChunkedDataset(
            self.dm.require(self.dset_cfg["key"])).open(cached=False)
        test_mask = np.load(
            f"{config.L5KIT_DATA_FOLDER}/scenes/mask.npz")["arr_0"]
        self.agent_dataset = AgentDataset(self.cfg,
                                          self.zarr_dataset,
                                          self.rasterizer,
                                          agents_mask=test_mask)
    else:
        zarr_path = self.dm.require(self.dset_cfg["key"])
        print(f"Opening Chunked Dataset {zarr_path}...")
        self.zarr_dataset = ChunkedDataset(zarr_path).open(cached=False)
        print("Creating Agent Dataset...")
        self.agent_dataset = AgentDataset(
            self.cfg,
            self.zarr_dataset,
            self.rasterizer,
            min_frame_history=0,
            min_frame_future=10,
        )
        print("Creating Agent Dataset... [OK]")

        if dset_name == LyftDataset.DSET_VALIDATION:
            # Restrict validation agents to those present at frame 100 of
            # each scene (mirrors the chopped-validation protocol).
            # Fix: np.bool was removed in numpy>=1.24 -> use builtin bool.
            mask_frame100 = np.zeros(
                shape=self.agent_dataset.agents_mask.shape, dtype=bool)
            for scene in self.agent_dataset.dataset.scenes:
                frame_interval = scene["frame_index_interval"]
                agent_index_interval = self.agent_dataset.dataset.frames[
                    frame_interval[0] + 99]["agent_index_interval"]
                mask_frame100[
                    agent_index_interval[0]:agent_index_interval[1]] = True
            prev_agents_num = np.sum(self.agent_dataset.agents_mask)
            self.agent_dataset.agents_mask = self.agent_dataset.agents_mask * mask_frame100
            print(
                f"nb agent: orig {prev_agents_num} filtered {np.sum(self.agent_dataset.agents_mask)}"
            )
            # store the valid agents indexes
            self.agent_dataset.agents_indices = np.nonzero(
                self.agent_dataset.agents_mask)[0]

    self.w, self.h = self.cfg["raster_params"]["raster_size"]
    self.add_agent_state = self.cfg["model_params"]["add_agent_state"]
    self.agent_state = None
# import catboost as cb # pd.set_option('max_columns', 50) os.environ["L5KIT_DATA_FOLDER"] = "data" cfg = load_config_data("examples/visualisation/visualisation_config.yaml") print(cfg) # Loading sample data for EDA # set env variable for data dm = LocalDataManager() dataset_path = dm.require('scenes/sample.zarr') zarr_dataset = ChunkedDataset(dataset_path) zarr_dataset.open() print(zarr_dataset) frames = zarr_dataset.frames agents = zarr_dataset.agents scenes = zarr_dataset.scenes # tl_faces = zarr_dataset.tl_faces # set env variable for data os.environ["L5KIT_DATA_FOLDER"] = "data"
train_agents_mask = None if flags.validation_chopped: # Use chopped dataset to calc statistics... num_frames_to_chop = 100 th_agent_prob = cfg["raster_params"]["filter_agents_threshold"] min_frame_future = 1 num_frames_to_copy = num_frames_to_chop train_agents_mask = load_mask_chopped(dm.require(train_path), th_agent_prob, num_frames_to_copy, min_frame_future) print("train_path", train_path, "train_agents_mask", train_agents_mask.shape) train_zarr = ChunkedDataset(dm.require(train_path)).open(cached=False) print("train_zarr", type(train_zarr)) print(f"Open Dataset {flags.pred_mode}...") train_agent_dataset = FasterAgentDataset( cfg, train_zarr, rasterizer, min_frame_history=flags.min_frame_history, min_frame_future=flags.min_frame_future, agents_mask=train_agents_mask) print("train_agent_dataset", len(train_agent_dataset)) n_sample = 1_000_000 # Take 1M sample. target_scale_abs_mean, target_scale_abs_max, target_scale_std = calc_target_scale( train_agent_dataset, n_sample)
def evaluate(cfg, model, dm, rasterizer, first_time, iters, eval_dataloader,
             eval_gt_path):
    """Run evaluation on the chopped validation set and print metrics.

    On the first call (``first_time`` is True) the chopped dataset is created
    and the eval dataloader / gt path are built; subsequent calls reuse the
    ones passed in. Predictions are written to ``pred_{iters}.csv`` under
    ``cfg["save_path"]`` and scored with ``compute_metrics_csv``.

    Args:
        cfg: experiment config dict.
        model: torch model; temporarily switched to eval mode.
        dm: LocalDataManager for resolving dataset paths.
        rasterizer: rasterizer used to build the eval AgentDataset.
        first_time: whether the chopped eval dataset must still be created.
        iters: current training iteration (used in the csv filename).
        eval_dataloader: existing eval DataLoader (ignored when first_time).
        eval_gt_path: existing gt csv path (ignored when first_time).

    Returns:
        (first_time, eval_dataloader, eval_gt_path) so the caller can cache
        the dataloader and gt path for the next call.
    """
    if first_time:
        num_frames_to_chop = 100
        print("min_future_steps: ", MIN_FUTURE_STEPS)
        eval_cfg = cfg["val_data_loader"]
        eval_base_path = create_chopped_dataset(
            dm.require(eval_cfg["key"]),
            cfg["raster_params"]["filter_agents_threshold"],
            num_frames_to_chop, cfg["model_params"]["future_num_frames"],
            MIN_FUTURE_STEPS)
        eval_zarr_path = str(
            Path(eval_base_path) / Path(dm.require(eval_cfg["key"])).name)
        eval_mask_path = str(Path(eval_base_path) / "mask.npz")
        eval_gt_path = str(Path(eval_base_path) / "gt.csv")
        eval_zarr = ChunkedDataset(eval_zarr_path).open()
        eval_mask = np.load(eval_mask_path)["arr_0"]
        eval_dataset = AgentDataset(cfg, eval_zarr, rasterizer,
                                    agents_mask=eval_mask)
        eval_dataloader = DataLoader(eval_dataset,
                                     shuffle=eval_cfg["shuffle"],
                                     batch_size=eval_cfg["batch_size"],
                                     num_workers=eval_cfg["num_workers"])
        print(eval_dataset)
        first_time = False

    model.eval()
    torch.set_grad_enabled(False)

    # Accumulators for the prediction csv.
    future_coords_offsets_pd = []
    timestamps = []
    confidences_list = []
    agent_ids = []

    progress_bar = tqdm(eval_dataloader)
    for data in progress_bar:
        _, preds, confidences = forward(data, model)

        # Convert agent-frame coordinates into world-frame offsets.
        preds = preds.cpu().numpy()
        world_from_agents = data["world_from_agent"].numpy()
        centroids = data["centroid"].numpy()
        for idx in range(len(preds)):
            # Generalized: iterate over the actual number of prediction
            # modes instead of the previously hard-coded 3.
            for mode in range(preds.shape[1]):
                preds[idx, mode, :, :] = transform_points(
                    preds[idx, mode, :, :],
                    world_from_agents[idx]) - centroids[idx][:2]
        future_coords_offsets_pd.append(preds.copy())
        confidences_list.append(confidences.cpu().numpy().copy())
        timestamps.append(data["timestamp"].numpy().copy())
        agent_ids.append(data["track_id"].numpy().copy())

    # Restore training state before returning.
    model.train()
    torch.set_grad_enabled(True)

    pred_path = os.path.join(cfg["save_path"], f"pred_{iters}.csv")
    write_pred_csv(pred_path,
                   timestamps=np.concatenate(timestamps),
                   track_ids=np.concatenate(agent_ids),
                   coords=np.concatenate(future_coords_offsets_pd),
                   confs=np.concatenate(confidences_list))
    metrics = compute_metrics_csv(eval_gt_path, pred_path,
                                  [neg_multi_log_likelihood, time_displace])
    for metric_name, metric_mean in metrics.items():
        print(metric_name, metric_mean)
    return first_time, eval_dataloader, eval_gt_path
'key': 'scenes/test.zarr', 'batch_size': 8, 'shuffle': False, 'num_workers': 4 }, 'train_params': { 'checkpoint_every_n_steps': 5000, 'max_num_steps': 1000, } } zarr_loc = cfg["train_data_loader"]["key"] dm = LocalDataManager() train_zarr = ChunkedDataset(dm.require(zarr_loc)).open() #let's see what one of the objects looks like print(train_zarr) #Let us visualize the EGO DATASET rast = build_rasterizer(cfg, dm) #Rasterisation is one of the typical techniques of rendering 3D models dataset = EgoDataset(cfg, train_zarr , rast) #For visualizing the Av. Av stands for Autonomous Vehicle #Let's get a sample from the dataset and use our rasterizer to get an RGB image we can plot data = dataset[70] im = data['image'].transpose(1, 2, 0) im = dataset.rasterizer.to_rgb(im) target_positions_pixels = transform_points(data["target_positions"] + data["centroid"][:2], data["world_to_image"])
# confs=np.concatenate(confidences_list) # ) # metrics = compute_metrics_csv( # f"{DATA_DIR}/scenes/validate_chopped_100/gt.csv", # val_predictions_file, # [neg_multi_log_likelihood, time_displace], # ) # for metric_name, metric_mean in metrics.items(): # print(metric_name, metric_mean) # ====== INIT TEST DATASET============================================================= rasterizer = build_rasterizer(cfg, dm) test_zarr = ChunkedDataset(dm.require("scenes/test.zarr")).open() test_mask = np.load(f"{DATA_DIR}/scenes/mask.npz")["arr_0"] test_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask) test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=32, num_workers=30) model.eval() torch.set_grad_enabled(False) # store information for evaluation future_coords_offsets_pd = [] ground_truth = [] timestamps = [] confidences_list = [] agent_ids = [] with tqdm(total=len(test_dataloader)) as progress:
def get_loaders(train_batch_size=32, valid_batch_size=64):
    """Prepare loaders.

    Args:
        train_batch_size (int, optional): batch size for training dataset.
            Default is `32`.
        valid_batch_size (int, optional): batch size for validation dataset.
            Default is `64`.

    Returns:
        tuple: ``(train_loader, (valid_loader, valid_gt_path))`` - the
        training DataLoader, plus the validation DataLoader paired with the
        path to its ground-truth csv.
    """
    rasterizer = build_rasterizer(cfg, dm)

    DATASET_CLASS = SegmentationAgentDataset

    train_zarr = ChunkedDataset(dm.require("scenes/train.zarr")).open()
    train_dataset = DATASET_CLASS(cfg, train_zarr, rasterizer)

    # Trajectory sizes are precomputed offline; "small" trajectories
    # (< 6 points) are downsampled to 1/4 to rebalance the training set.
    sizes = ps.read_csv(os.environ["TRAIN_TRAJ_SIZES"])["size"].values
    is_small = sizes < 6
    n_points = is_small.sum()
    to_sample = n_points // 4
    print(" * points - {} (points to sample - {})".format(n_points, to_sample))
    print(" * paths -", sizes.shape[0] - n_points)
    indices = np.concatenate([
        np.random.choice(
            np.where(is_small)[0],
            size=to_sample,
            replace=False,
        ),
        np.where(~is_small)[0],
    ])
    # Keep the first half of the selected indices. Equivalent to the former
    # Subset(Subset(dataset, indices), range(len(indices) // 2)) chain, but
    # with a single wrapper.
    train_dataset = Subset(train_dataset, indices[:len(indices) // 2])
    train_loader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        num_workers=NUM_WORKERS,
        shuffle=True,
        worker_init_fn=seed_all,
        drop_last=True,
    )
    print(f" * Number of elements in train dataset - {len(train_dataset)}")
    print(f" * Number of elements in train loader - {len(train_loader)}")

    # Validation uses the pre-chopped 10-frame split with its agents mask.
    valid_zarr_path = dm.require("scenes/validate_chopped_10/validate.zarr")
    mask_path = dm.require("scenes/validate_chopped_10/mask.npz")
    valid_mask = np.load(mask_path)["arr_0"]
    valid_gt_path = dm.require("scenes/validate_chopped_10/gt.csv")

    valid_zarr = ChunkedDataset(valid_zarr_path).open()
    valid_dataset = DATASET_CLASS(cfg,
                                  valid_zarr,
                                  rasterizer,
                                  agents_mask=valid_mask)
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=valid_batch_size,
        shuffle=False,
        num_workers=NUM_WORKERS,
    )
    print(f" * Number of elements in valid dataset - {len(valid_dataset)}")
    print(f" * Number of elements in valid loader - {len(valid_loader)}")

    return train_loader, (valid_loader, valid_gt_path)
def dataset() -> ChunkedDataset: zarr_dataset = ChunkedDataset( path="./l5kit/tests/artefacts/single_scene.zarr") zarr_dataset.open() return zarr_dataset
def build_dataloader(
    cfg: Dict,
    split: str,
    data_manager: DataManager,
    dataset_class: Callable,
    rasterizer: Rasterizer,
    perturbation: Optional[Perturbation] = None,
    combine_scenes: bool = False,
) -> DataLoader:
    """
    Function to build a dataloader from a dataset of dataset_class. Note we have to pass rasterizer and perturbation
    as the factory functions for those are likely to change between repos.

    Args:
        cfg (dict): configuration dict
        split (str): this will be used to index the cfg to get the correct datasets (train or val currently)
        data_manager (DataManager): manager for resolving paths
        dataset_class (Callable): a class object (EgoDataset or AgentDataset currently) to build the dataset
        rasterizer (Rasterizer): the rasterizer for the dataset
        perturbation (Optional[Perturbation]): an optional perturbation object
        combine_scenes (bool): if to combine scenes that follow up each other perfectly

    Returns:
        DataLoader: pytorch Dataloader object built with Concat and Sub datasets
    """
    data_loader_cfg = cfg[f"{split}_data_loader"]

    datasets = []
    for dataset_param in data_loader_cfg["datasets"]:
        zarr_dataset_path = data_manager.require(key=dataset_param["key"])
        zarr_dataset = ChunkedDataset(path=zarr_dataset_path)
        zarr_dataset.open()
        if combine_scenes:
            # possible future deprecation
            zarr_dataset.scenes = get_combined_scenes(zarr_dataset.scenes)

        # Let's load the zarr dataset with our dataset.
        dataset = dataset_class(cfg, zarr_dataset, rasterizer,
                                perturbation=perturbation)

        scene_indices = dataset_param["scene_indices"]
        if scene_indices[0] == -1:
            # -1 is a sentinel meaning "use all scenes". TODO replace with empty
            scene_subsets = [Subset(dataset, np.arange(0, len(dataset)))]
        else:
            scene_subsets = [
                Subset(dataset, dataset.get_scene_indices(scene_idx))
                for scene_idx in scene_indices
            ]
        datasets.extend(scene_subsets)

    # Let's concatenate the training scenes into one dataset for the data loader to load from.
    concat_dataset: ConcatDataset = ConcatDataset(datasets)

    # Initialize the data loader that our training loop will iterate on.
    batch_size = data_loader_cfg["batch_size"]
    shuffle = data_loader_cfg["shuffle"]
    num_workers = data_loader_cfg["num_workers"]
    dataloader = DataLoader(dataset=concat_dataset,
                            batch_size=batch_size,
                            shuffle=shuffle,
                            num_workers=num_workers)
    return dataloader