def get_search_iteration_manager(experiment_group):
    """Return an iteration manager for ``experiment_group``, if one applies.

    Only the Hyperband and BO search algorithms have an iteration manager
    here; for any other ``search_algorithm`` this returns ``None``.

    :param experiment_group: object exposing a ``search_algorithm`` attribute;
        it is also passed to the manager's constructor.
    :return: a ``HyperbandIterationManager`` / ``BOIterationManager`` instance,
        or ``None``.
    """
    algorithm = experiment_group.search_algorithm

    if SearchAlgorithms.is_hyperband(algorithm):
        manager_cls = HyperbandIterationManager
    elif SearchAlgorithms.is_bo(algorithm):
        manager_cls = BOIterationManager
    else:
        manager_cls = None

    if manager_cls is None:
        return None
    return manager_cls(experiment_group=experiment_group)
def get_iteration_config(search_algorithm, iteration=None):
    """Deserialize ``iteration`` into the config class for ``search_algorithm``.

    Hyperband and BO require a truthy ``iteration``; otherwise a
    ``ValueError`` is raised. Any other algorithm falls through to
    ``BaseIterationConfig.from_dict``, which receives ``iteration`` as-is
    (possibly ``None``).

    :param search_algorithm: value checked with the ``SearchAlgorithms``
        predicates.
    :param iteration: raw iteration data handed to ``from_dict``.
    :raises ValueError: for Hyperband/BO when ``iteration`` is falsy.
    """
    if SearchAlgorithms.is_hyperband(search_algorithm):
        config_cls = HyperbandIterationConfig
    elif SearchAlgorithms.is_bo(search_algorithm):
        config_cls = BOIterationConfig
    else:
        # No mandatory-iteration check for the remaining algorithms.
        return BaseIterationConfig.from_dict(iteration)

    # Hyperband and BO share the same "iteration is required" guard.
    if not iteration:
        raise ValueError('No iteration was provided')
    return config_cls.from_dict(iteration)
def get_search_algorithm_manager(hptuning_config):
    """Return the search manager matching ``hptuning_config.search_algorithm``.

    Predicates are tried in order: grid, random, hyperband, BO. The first
    match wins.

    :param hptuning_config: config object exposing ``search_algorithm``; it is
        passed to the chosen manager's constructor.
    :return: a manager instance, or ``None`` when ``hptuning_config`` is falsy
        or no algorithm predicate matches.
    """
    if not hptuning_config:
        return None

    algorithm = hptuning_config.search_algorithm
    # Ordered (predicate, manager class) pairs; order mirrors the original checks.
    dispatch = (
        (SearchAlgorithms.is_grid, GridSearchManager),
        (SearchAlgorithms.is_random, RandomSearchManager),
        (SearchAlgorithms.is_hyperband, HyperbandSearchManager),
        (SearchAlgorithms.is_bo, BOSearchManager),
    )
    for matches, manager_cls in dispatch:
        if matches(algorithm):
            return manager_cls(hptuning_config=hptuning_config)
    return None
def start(experiment_group):
    """Send the Celery start task for the group's search algorithm.

    The task name is chosen from ``HPCeleryTasks`` by testing grid, random,
    hyperband and BO in that order. It is sent through
    ``celery_app.send_task`` with ``experiment_group_id`` as its only kwarg
    and ``countdown=1``. Nothing is sent when no algorithm matches.

    :param experiment_group: object exposing ``search_algorithm`` and ``id``.
    """
    algorithm = experiment_group.search_algorithm
    routes = (
        (SearchAlgorithms.is_grid, HPCeleryTasks.HP_GRID_SEARCH_START),
        (SearchAlgorithms.is_random, HPCeleryTasks.HP_RANDOM_SEARCH_START),
        (SearchAlgorithms.is_hyperband, HPCeleryTasks.HP_HYPERBAND_START),
        (SearchAlgorithms.is_bo, HPCeleryTasks.HP_BO_START),
    )
    # Lazy generator: predicates stop being evaluated at the first match.
    task = next(
        (task_name for matches, task_name in routes if matches(algorithm)),
        None,
    )
    if task:
        celery_app.send_task(
            task,
            kwargs={'experiment_group_id': experiment_group.id},
            countdown=1,
        )
def create(experiment_group):
    """Audit and create the experiments for ``experiment_group``.

    Grid, random, hyperband and BO are tested in that order. On a match the
    matching ``EXPERIMENT_GROUP_*`` event is recorded via ``auditor.record``,
    and then the algorithm module's ``create`` result is returned.

    :param experiment_group: object exposing ``search_algorithm``.
    :return: the value returned by ``<module>.create(...)``, or ``None`` when
        no algorithm predicate matches (nothing is audited in that case).
    """
    algorithm = experiment_group.search_algorithm

    # NOTE(review): ``random`` below is presumably a project module exposing
    # ``create`` (it shadows the stdlib name) — confirm against the imports.
    if SearchAlgorithms.is_grid(algorithm):
        event_type, handler = EXPERIMENT_GROUP_GRID, grid
    elif SearchAlgorithms.is_random(algorithm):
        event_type, handler = EXPERIMENT_GROUP_RANDOM, random
    elif SearchAlgorithms.is_hyperband(algorithm):
        event_type, handler = EXPERIMENT_GROUP_HYPERBAND, hyperband
    elif SearchAlgorithms.is_bo(algorithm):
        event_type, handler = EXPERIMENT_GROUP_BO, bo
    else:
        return None

    # Audit first, then delegate, matching the original per-branch order.
    auditor.record(event_type=event_type, instance=experiment_group)
    return handler.create(experiment_group=experiment_group)