default='MLP')
args = parser.parse_args()

# Apply the requested log level to the multiml logger.
from multiml import logger
logger.set_level(args.log_level)

# Initialise TensorFlow through the project helper with the run's seed and
# args.igpu (presumably a GPU index -- confirm in setup_tensorflow).
from setup_tensorflow import setup_tensorflow
setup_tensorflow(args.seed, args.igpu)

# Build the output directory name. NOTE(review): `save_dir` must already be
# defined earlier in the file (not visible here). add_suffix(save_dir, args)
# is a project helper that rewrites it based on the parsed args; the model
# name is then appended as the final suffix.
from run_utils import add_suffix
save_dir = add_suffix(save_dir, args)
save_dir += f'_{args.model}'

# Saver for run outputs, with serial_id set to the seed; the seed value is
# also stored in the saver under the "seed" key.
from multiml.saver import Saver
saver = Saver(save_dir, serial_id=args.seed)
saver.add("seed", args.seed)

# Storegate: input data loaded from args.data_path (args.max_events is
# presumably a cap on the number of events -- confirm in get_storegate).
from my_storegate import get_storegate
storegate = get_storegate(
    data_path=args.data_path,
    max_events=args.max_events,
)

# Task scheduler, created empty here; tasks are presumably registered
# further down the file (not visible in this chunk).
from multiml.task_scheduler import TaskScheduler
task_scheduler = TaskScheduler()
subtask_args = {
    'saver': saver,
    'output_var_names': ('probability', ),
    'true_var_names': 'label',
def preprocessing(save_dir,
                  args,
                  tau4vec_tasks=('MLP', 'conv2D', 'SF', 'zero', 'noise'),
                  higgsId_tasks=('mlp', 'lstm', 'mass', 'zero', 'noise'),
                  truth_intermediate_inputs=True):
    """Build the saver, storegate, task scheduler and metric for one run.

    Args:
        save_dir: Output directory handed to ``Saver``.
        args: Parsed command-line namespace. This function reads
            ``log_level``, ``seed``, ``igpu``, ``load_weights``,
            ``data_path``, ``max_events``, ``remove_dummy_models`` and
            ``run_eagerly`` from it.
        tau4vec_tasks: Names of the tau4vec subtask models to build. Any
            iterable of strings is accepted; it is copied, never mutated.
        higgsId_tasks: Names of the higgsId subtask models to build. Any
            iterable of strings is accepted; it is copied, never mutated.
        truth_intermediate_inputs: Forwarded as ``truth_input`` to
            ``get_higgsId_subtasks`` when both tau4vec and higgsId tasks
            are built.

    Returns:
        tuple: ``(saver, storegate, task_scheduler, metric)``. The metric is
        a ``CustomMSEMetric`` when only tau4vec tasks are built, otherwise
        an ``AUCMetric`` on ``probability`` vs ``label``. Both are evaluated
        in the ``'test'`` phase.

    Raises:
        ValueError: If both task lists are empty, possibly because
            ``args.remove_dummy_models`` removed the ``zero``/``noise``
            entries.
    """
    # Imports stay inside the function, as in the original. This is
    # presumably so logging and TensorFlow are configured before the
    # project modules are loaded.
    from multiml import logger
    logger.set_level(args.log_level)

    from setup_tensorflow import setup_tensorflow
    setup_tensorflow(args.seed, args.igpu)

    load_weights = args.load_weights

    from multiml.saver import Saver
    saver = Saver(save_dir, serial_id=args.seed)
    saver.add("seed", args.seed)

    # Storegate
    from my_storegate import get_storegate
    storegate = get_storegate(
        data_path=args.data_path,
        max_events=args.max_events,
    )

    # Task scheduler
    from multiml.task_scheduler import TaskScheduler
    from my_tasks import get_higgsId_subtasks, get_tau4vec_subtasks
    task_scheduler = TaskScheduler()

    # Work on private copies. The defaults are immutable tuples and the
    # caller's lists must not be shared with the subtask builders below.
    tau4vec_tasks = list(tau4vec_tasks)
    higgsId_tasks = list(higgsId_tasks)

    if args.remove_dummy_models:
        # 'zero' and 'noise' are the placeholder models dropped on request.
        dummy_models = {'zero', 'noise'}
        tau4vec_tasks = [v for v in tau4vec_tasks if v not in dummy_models]
        higgsId_tasks = [v for v in higgsId_tasks if v not in dummy_models]

    if tau4vec_tasks and higgsId_tasks:
        # Two-stage chain: higgsId consumes tau4vec's outputs, so it is
        # registered with tau4vec as its parent.
        subtask1 = get_higgsId_subtasks(saver,
                                        subtask_names=higgsId_tasks,
                                        truth_input=truth_intermediate_inputs,
                                        load_weights=load_weights,
                                        run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='higgsId',
                                parents=['tau4vec'],
                                children=[],
                                subtasks=subtask1)

        subtask2 = get_tau4vec_subtasks(saver,
                                        subtask_names=tau4vec_tasks,
                                        load_weights=load_weights,
                                        run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='tau4vec',
                                parents=[],
                                children=['higgsId'],
                                subtasks=subtask2)

    elif higgsId_tasks:
        # NOTE(review): truth_intermediate_inputs is not forwarded here, so
        # get_higgsId_subtasks' own default for truth_input applies.
        # Confirm this is intended for the standalone higgsId case.
        subtask = get_higgsId_subtasks(saver,
                                       subtask_names=higgsId_tasks,
                                       load_weights=load_weights,
                                       run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='higgsId', subtasks=subtask)

    elif tau4vec_tasks:
        subtask = get_tau4vec_subtasks(saver,
                                       subtask_names=tau4vec_tasks,
                                       load_weights=load_weights,
                                       run_eagerly=args.run_eagerly)
        task_scheduler.add_task(task_id='tau4vec', subtasks=subtask)

    else:
        raise ValueError("Strange task combination...")

    # Metric: regression-style MSE for a tau4vec-only run, otherwise AUC on
    # the classifier output.
    if tau4vec_tasks and not higgsId_tasks:
        from multiml_htautau.task.metrics import CustomMSEMetric
        from my_tasks import corr_tau_4vec, truth_tau_4vec
        metric = CustomMSEMetric(pred_var_name=corr_tau_4vec,
                                 true_var_name=truth_tau_4vec,
                                 phase='test')
    else:
        from multiml.agent.metric import AUCMetric
        metric = AUCMetric(pred_var_name='probability',
                           true_var_name='label',
                           phase='test')

    return saver, storegate, task_scheduler, metric