def _prepare_evaluation_param(param: Parameter) -> Parameter:
    """Configure ``param`` in place for evaluation and return it.

    ``param.eval_keyword`` is normalized according to ``param.model_level``:

    * ``'low'``: upper-cased and converted with ``fetch_road_option_from_str``;
    * ``'high'``: lower-cased.

    Afterwards the loader settings are forced to one unshuffled sample per
    batch, ``max_data_length = -1`` (presumably "no length limit" — confirm
    against the dataset loader), and the dataset lists are restricted to the
    single evaluation data/info name. Low-level models additionally get
    multi-camera input, sequences and clusters disabled.

    Args:
        param: parameter object; ``eval_data_name``, ``eval_info_name`` and
            ``eval_keyword`` must be set (truthy).

    Returns:
        The same ``param`` object, mutated.

    Raises:
        AssertionError: if a required evaluation field is empty (only checked
            when assertions are enabled, i.e. not under ``python -O``).
        TypeError: if ``param.model_level`` is neither ``'low'`` nor ``'high'``.
    """
    assert param.eval_data_name, 'eval_data_name must be set'
    assert param.eval_info_name, 'eval_info_name must be set'
    assert param.eval_keyword, 'eval_keyword must be set'
    if param.model_level == 'low':
        param.eval_keyword = fetch_road_option_from_str(param.eval_keyword.upper())
    elif param.model_level == 'high':
        param.eval_keyword = param.eval_keyword.lower()
    else:
        # The offending field is model_level (not eval_keyword); report both so
        # the failure is diagnosable from the message alone.
        logger.error('invalid model_level: {}'.format(param.model_level))
        raise TypeError('invalid model_level was given {} (eval_keyword: {})'.format(
            param.model_level, param.eval_keyword))
    param.max_data_length = -1
    param.shuffle = False
    param.batch_size = 1
    # Restrict the dataset to exactly the evaluation data/info pair.
    param.dataset_data_names = [param.eval_data_name]
    param.dataset_info_names = [param.eval_info_name]
    if param.model_level == 'low':
        param.use_multi_cam = False
        param.use_sequence = False
        param.has_clusters = False
    return param
def load_expert_trajectory(data_name: str, info_name: str, data_keyword: str) -> EvaluationTrajectoryGroup:
    """Load the dataset ("expert") trajectories for one evaluation keyword.

    A fresh high-level ``Parameter`` is pointed at the given data/info names
    and passed to ``load_evaluation_dataset``; every data frame it returns is
    wrapped in an ``EvaluationTrajectory``.

    Args:
        data_name: assigned to ``param.eval_data_name``.
        info_name: assigned to ``param.eval_info_name``.
        data_keyword: assigned to ``param.eval_keyword`` and used (unmodified)
            as the keyword of every ``EvaluationUnitInfo`` created here.

    Returns:
        An ``EvaluationTrajectoryGroup`` whose group info has index ``-1`` and
        whose i-th trajectory carries an ``EvaluationUnitInfo`` with index ``i``.
    """
    param = Parameter()
    param.model_level = 'high'
    param.eval_keyword = data_keyword
    param.eval_data_name = data_name
    param.eval_info_name = info_name
    # The loader's second return value (sentences) is not needed here.
    data_frame_list, _ = load_evaluation_dataset(param)
    # NOTE(review): the meaning of the positional placeholders (-1, '', -1) in
    # EvaluationUnitInfo and of the trailing False in EvaluationTrajectory is
    # not visible from this file — confirm against their definitions.
    trajectories = [
        EvaluationTrajectory(EvaluationUnitInfo(data_keyword, -1, '', -1, index), data_frame, False)
        for index, data_frame in enumerate(data_frame_list)
    ]
    group_info = EvaluationUnitInfo(data_keyword, -1, '', -1, -1)
    return EvaluationTrajectoryGroup(group_info, trajectories)
def load_param_and_evaluator(eval_keyword: str, args, model_type: str):
    """Load experiment parameters and a ready-to-use evaluator for ``model_type``.

    ``model_type`` selects which trio of ``args`` attributes identifies the
    experiment (``<prefix>_index``, ``<prefix>_name``, ``<prefix>_step``):

    * ``'control'`` -> ``control_model``
    * ``'stop'``    -> ``stop_model``
    * ``'high'``    -> ``high_level``
    * ``'single'``  -> ``single_model``

    ``'control'`` / ``'stop'`` produce a ``LowLevelEvaluator`` built with
    ``args.exp_cmd``; ``'single'`` produces a ``SingleEvaluator`` and
    ``'high'`` a ``HighLevelEvaluator``, both built with ``eval_keyword``.
    The evaluator is loaded at the selected step before being returned.

    Returns:
        ``(param, evaluator)``.

    Raises:
        TypeError: if ``model_type`` is not one of the four names above.
    """
    args_prefix_by_type = {
        'control': 'control_model',
        'stop': 'stop_model',
        'high': 'high_level',
        'single': 'single_model',
    }
    param = Parameter()
    is_low_level = model_type in ('control', 'stop')
    if model_type not in args_prefix_by_type:
        raise TypeError('invalid model type {}'.format(model_type))
    prefix = args_prefix_by_type[model_type]
    exp_index = getattr(args, prefix + '_index')
    exp_name = getattr(args, prefix + '_name')
    exp_step = getattr(args, prefix + '_step')

    param.exp_name = exp_name
    param.exp_index = exp_index
    param.load()
    # Evaluation-time overrides applied after the stored parameters are loaded.
    param.batch_size = 1
    param.eval_keyword = eval_keyword
    param.eval_data_name = args.eval_data_name
    param.eval_info_name = args.eval_info_name
    logger.info('model type: {}'.format(param.model_type))

    if is_low_level:
        evaluator_cls = LowLevelEvaluator
    elif model_type == 'single':
        evaluator_cls = SingleEvaluator
    else:
        evaluator_cls = HighLevelEvaluator
    logger.info((model_type, evaluator_cls, param.model_level, param.encoder_type))

    evaluator = evaluator_cls(param, args.exp_cmd if is_low_level else eval_keyword)
    evaluator.load(step=exp_step)
    return param, evaluator