def load_dagger_generator(args):
    """Build a DaggerGeneratorEnvironment over this worker's share of trajectories.

    The experiment's saved ``Parameter`` is loaded and overridden for
    evaluation (batch size 1, ``args.dataset_name`` as the only data name,
    ``args.eval_dataset_name`` as the evaluation data name). Every trajectory
    of the resulting training dataset is summarized. The summaries are split
    round-robin across ``args.ports``, and only the share belonging to
    ``args.port`` is handed to the environment.

    Args:
        args: Namespace-like object. Reads ``port``, ``ports``,
            ``dataset_name``, ``eval_dataset_name``, ``exp_index`` and
            ``exp_name``. It is also forwarded to the environment.

    Returns:
        ``DaggerGeneratorEnvironment(args, eval_param, index_data_list)``.
        ``index_data_list`` is a list of ``(trajectory_index, summary)``
        pairs. Each summary is a dict with keys ``'road_option'``,
        ``'action_index'``, ``'src_transform'``, ``'dst_location'`` and
        ``'length'``.

    Raises:
        ValueError: if ``ports`` is empty, ``port`` is not in ``ports``, or
            either dataset name is missing/empty.
    """
    port = args.port
    ports = args.ports
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if not ports:
        raise ValueError('args.ports must be a non-empty sequence')
    if port not in ports:
        raise ValueError('port {} is not in ports {}'.format(port, ports))
    if not args.dataset_name:
        raise ValueError('args.dataset_name is required')
    if not args.eval_dataset_name:
        raise ValueError('args.eval_dataset_name is required')
    port_index = ports.index(port)

    eval_param = Parameter()
    eval_param.exp_index = args.exp_index
    eval_param.exp_name = args.exp_name
    eval_param.load()
    # Overrides applied after load() so they win over the saved experiment values.
    eval_param.batch_size = 1
    eval_param.dataset_data_names = [args.dataset_name]
    eval_param.eval_data_name = args.eval_dataset_name
    eval_param.max_data_length = -1  # presumably "no length cap" — confirm in Parameter
    if eval_param.split_train:
        # Only the training split is used; the validation split is discarded.
        train_dataset, _ = fetch_dataset_pair(eval_param)
    else:
        train_dataset = fetch_dataset(eval_param)

    index_func = partial(
        fetch_index_from_road_option,
        low_level=eval_param.use_low_level_segment)
    data_list = []
    for i in range(len(train_dataset)):
        # NOTE(review): carried over from the original fixme —
        # get_trajectory_data may not exist on every dataset type
        # (e.g. HighLevelDataset); confirm before using with such data.
        road_option, images, drives = train_dataset.get_trajectory_data(i)
        data_list.append({
            'road_option': road_option,
            'action_index': index_func(road_option),
            'src_transform': drives[0].state.transform,
            'dst_location': drives[-1].state.transform.location,
            'length': len(images),
        })

    # Round-robin sharding: worker k receives trajectories k, k+n, k+2n, ...
    # where n = len(ports). The original trajectory index is kept with each item.
    index_data_list = list(enumerate(data_list))[port_index::len(ports)]
    return DaggerGeneratorEnvironment(args, eval_param, index_data_list)
def load_param_and_evaluator(eval_keyword: str, args, model_type: str):
    """Load a trained model's Parameter and build the matching evaluator.

    ``model_type`` selects which group of ``args`` attributes identifies the
    experiment:

        'control' -> args.control_model_{index,name,step}
        'stop'    -> args.stop_model_{index,name,step}
        'high'    -> args.high_level_{index,name,step}
        'single'  -> args.single_model_{index,name,step}

    The Parameter is loaded for that experiment, then overridden with batch
    size 1, ``eval_keyword``, ``args.eval_data_name`` and
    ``args.eval_info_name``.

    Evaluator selection:

    - 'control' and 'stop' are low-level and get ``LowLevelEvaluator``,
      constructed with ``args.exp_cmd``.
    - 'single' gets ``SingleEvaluator``, constructed with ``eval_keyword``.
    - 'high' gets ``HighLevelEvaluator``, constructed with ``eval_keyword``.

    ``evaluator.load(step=...)`` is then called with the selected step
    (presumably restoring that checkpoint).

    Returns:
        ``(param, evaluator)``.

    Raises:
        TypeError: if ``model_type`` is not one of the four values above.
    """
    param = Parameter()
    attr_prefix_by_type = {
        'control': 'control_model',
        'stop': 'stop_model',
        'high': 'high_level',
        'single': 'single_model',
    }
    if model_type not in attr_prefix_by_type:
        raise TypeError('invalid model type {}'.format(model_type))
    prefix = attr_prefix_by_type[model_type]
    exp_index = getattr(args, prefix + '_index')
    exp_name = getattr(args, prefix + '_name')
    exp_step = getattr(args, prefix + '_step')
    is_low_level = model_type in ('control', 'stop')

    param.exp_name = exp_name
    param.exp_index = exp_index
    param.load()
    # Overrides applied after load() so they take precedence over saved values.
    param.batch_size = 1
    param.eval_keyword = eval_keyword
    param.eval_data_name = args.eval_data_name
    param.eval_info_name = args.eval_info_name
    logger.info('model type: {}'.format(param.model_type))

    if is_low_level:
        evaluator_cls = LowLevelEvaluator
    elif model_type == 'single':
        evaluator_cls = SingleEvaluator
    else:
        evaluator_cls = HighLevelEvaluator
    logger.info((model_type, evaluator_cls, param.model_level, param.encoder_type))

    # Low-level evaluators take the experiment command; the others take the keyword.
    if is_low_level:
        eval_arg = args.exp_cmd
    else:
        eval_arg = eval_keyword
    evaluator = evaluator_cls(param, eval_arg)
    evaluator.load(step=exp_step)
    return param, evaluator