default='MLP') args = parser.parse_args() from multiml import logger logger.set_level(args.log_level) from setup_tensorflow import setup_tensorflow setup_tensorflow(args.seed, args.igpu) from run_utils import add_suffix save_dir = add_suffix(save_dir, args) save_dir += f'_{args.model}' from multiml.saver import Saver saver = Saver(save_dir, serial_id=args.seed) saver.add("seed", args.seed) # Storegate from my_storegate import get_storegate storegate = get_storegate( data_path=args.data_path, max_events=args.max_events, ) from multiml.task_scheduler import TaskScheduler task_scheduler = TaskScheduler() subtask_args = { 'saver': saver, 'output_var_names': ('probability', ), 'true_var_names': 'label',
def preprocessing(save_dir,
                  args,
                  tau4vec_tasks=None,
                  higgsId_tasks=None,
                  truth_intermediate_inputs=True):
    """Set up the saver, storegate, task scheduler and metric for a run.

    Configures logging and TensorFlow from ``args``, creates a ``Saver`` in
    ``save_dir``, loads data through ``get_storegate`` and registers the
    requested ``tau4vec`` / ``higgsId`` subtasks on a ``TaskScheduler``.

    Args:
        save_dir: Output directory passed to ``Saver``.
        args: Parsed command-line namespace. Attributes read here:
            ``log_level``, ``seed``, ``igpu``, ``load_weights``,
            ``data_path``, ``max_events``, ``remove_dummy_models`` and
            ``run_eagerly``.
        tau4vec_tasks: Names of the tau4vec subtasks. ``None`` (the default)
            means ``['MLP', 'conv2D', 'SF', 'zero', 'noise']``.
        higgsId_tasks: Names of the higgsId subtasks. ``None`` (the default)
            means ``['mlp', 'lstm', 'mass', 'zero', 'noise']``.
        truth_intermediate_inputs: Forwarded as ``truth_input`` to
            ``get_higgsId_subtasks``, but only when both tau4vec and higgsId
            subtasks are scheduled.

    Returns:
        Tuple ``(saver, storegate, task_scheduler, metric)``. ``metric`` is a
        ``CustomMSEMetric`` comparing ``corr_tau_4vec`` with
        ``truth_tau_4vec`` (both from ``my_tasks``) when only tau4vec
        subtasks are scheduled; otherwise it is an ``AUCMetric`` comparing
        ``'probability'`` with ``'label'``. Both use ``phase='test'``.

    Raises:
        ValueError: If no subtasks remain, i.e. both lists are empty
            (possibly after the dummy models were removed).
    """
    # Build the default lists per call instead of using list literals in the
    # signature: a list default is created once and shared by every call, so
    # any mutation would leak into later calls.
    if tau4vec_tasks is None:
        tau4vec_tasks = ['MLP', 'conv2D', 'SF', 'zero', 'noise']
    if higgsId_tasks is None:
        higgsId_tasks = ['mlp', 'lstm', 'mass', 'zero', 'noise']

    from multiml import logger
    logger.set_level(args.log_level)

    from setup_tensorflow import setup_tensorflow
    setup_tensorflow(args.seed, args.igpu)

    load_weights = args.load_weights

    from multiml.saver import Saver
    saver = Saver(save_dir, serial_id=args.seed)
    saver.add("seed", args.seed)

    # Storegate
    from my_storegate import get_storegate
    storegate = get_storegate(
        data_path=args.data_path,
        max_events=args.max_events,
    )

    # Task scheduler
    from multiml.task_scheduler import TaskScheduler
    from my_tasks import get_higgsId_subtasks, get_tau4vec_subtasks
    task_scheduler = TaskScheduler()

    if args.remove_dummy_models:
        # 'zero' and 'noise' are the dummy models dropped by this flag.
        dummy_models = ('zero', 'noise')
        tau4vec_tasks = [v for v in tau4vec_tasks if v not in dummy_models]
        higgsId_tasks = [v for v in higgsId_tasks if v not in dummy_models]

    # Evaluated once, after the optional filtering above; reused for both the
    # scheduler wiring and the metric choice below.
    has_tau4vec = len(tau4vec_tasks) > 0
    has_higgsId = len(higgsId_tasks) > 0

    if has_tau4vec and has_higgsId:
        # Two-step chain: tau4vec is registered as the parent of higgsId.
        subtask1 = get_higgsId_subtasks(saver,
                                        subtask_names=higgsId_tasks,
                                        truth_input=truth_intermediate_inputs,
                                        load_weights=load_weights,
                                        run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='higgsId',
                                parents=['tau4vec'],
                                children=[],
                                subtasks=subtask1)

        subtask2 = get_tau4vec_subtasks(saver,
                                        subtask_names=tau4vec_tasks,
                                        load_weights=load_weights,
                                        run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='tau4vec',
                                parents=[],
                                children=['higgsId'],
                                subtasks=subtask2)
    elif has_higgsId:
        # NOTE(review): truth_intermediate_inputs is not forwarded here, so
        # get_higgsId_subtasks falls back to its own truth_input default —
        # confirm this is intended.
        subtask = get_higgsId_subtasks(saver,
                                       subtask_names=higgsId_tasks,
                                       load_weights=load_weights,
                                       run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='higgsId', subtasks=subtask)
    elif has_tau4vec:
        subtask = get_tau4vec_subtasks(saver,
                                       subtask_names=tau4vec_tasks,
                                       load_weights=load_weights,
                                       run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='tau4vec', subtasks=subtask)
    else:
        # Both lists are empty: there is nothing to schedule.
        raise ValueError("Strange task combination...")

    # Metric: MSE on the tau 4-vectors for a tau4vec-only run, AUC otherwise.
    if has_tau4vec and not has_higgsId:
        from multiml_htautau.task.metrics import CustomMSEMetric
        from my_tasks import corr_tau_4vec, truth_tau_4vec
        metric = CustomMSEMetric(pred_var_name=corr_tau_4vec,
                                 true_var_name=truth_tau_4vec,
                                 phase='test')
    else:
        from multiml.agent.metric import AUCMetric
        metric = AUCMetric(pred_var_name='probability',
                           true_var_name='label',
                           phase='test')

    return saver, storegate, task_scheduler, metric