def test_add_random(self):
    """Check that a non-persistent random parameter is re-sampled per grid.

    When ``persist`` is not set to True, ``add_random`` does not store the
    parameter as a fixed list of randomly generated values, so every call
    to ``param_grid`` is expected to sample new random values.
    """
    ps = ParamSpace()
    name = "param1"
    n_samples = 10
    ps.add_random(name, low=2, high=4, persist=False, n=n_samples, prior="uniform")

    params1 = ps.param_grid()
    # The previous ``assertTrue(ps.size, 1)`` used 1 as the failure message
    # and therefore passed for any truthy size. Compare explicitly instead.
    # NOTE(review): expected size assumes that n random samples contribute a
    # factor of n to the space size, as test_param_grid relies on for n=3.
    self.assertEqual(ps.size, n_samples)
    r1 = next(params1)[name]

    # A second grid built from the same space must draw a different value.
    params2 = ps.param_grid()
    r2 = next(params2)[name]

    # Exercises serialisation of a space holding a random parameter; the
    # output file name is relative to the current working directory.
    ps.write("test.cfg")

    self.assertNotEqual(r1, r2)
def test_param_grid(self):
    """Check that the grid enumerates every combination of parameter values.

    One fixed value, a two-element list and three random samples must
    produce 1 * 2 * 3 configurations, and that count must match ``ps.size``.
    """
    ps = ParamSpace()
    ps.add_value("p1", True)
    ps.add_list("p2", ["A", "B"])
    ps.add_random("p3", low=0, high=4, prior="uniform", n=3)

    # Materialise the grid once so its length can be checked; the earlier
    # discarded param_grid() call and commented-out debug prints were dead code.
    grid = list(ps.param_grid())

    self.assertEqual(len(grid), 1 * 2 * 3)
    self.assertEqual(len(grid), ps.size)
def test_param_grid_with_id(self):
    """Check the grid length when several runs per configuration are requested.

    A space with one fixed value and a two-element list, expanded with
    ``runs=5``, must yield 1 * 2 * 5 entries.
    """
    space = ParamSpace()
    space.add_value("p1", True)
    space.add_list("p2", ["A", "B"])

    expanded = list(space.param_grid(runs=5))
    n_entries = len(expanded)

    self.assertEqual(n_entries, 1 * 2 * 5)
def main(params, module, runs, name, workers, gpu, config_ids, cancel):
    """Run the configurations of a parameter space on a pool of worker processes.

    Configurations are read from ``params``, written to ``{name}_params.csv``
    and handed one at a time to ``worker`` processes through a queue. Errors
    are logged to ``{name}.log``; ids without a successful result are
    reported on stderr at the end.

    Args:
        params: file name given to ``ParamSpace(filename=...)``.
        module: forwarded unchanged to every ``worker`` process.
        runs: number of runs per configuration, passed to ``param_grid``.
        name: prefix of the log file and of the written configuration file.
        workers: upper bound on the number of worker processes.
        gpu: if truthy, the worker count is capped by the number of GPUs
            whose load is below 0.2; otherwise by the CPU count.
        config_ids: ids of the configurations to run; when empty, the whole
            parameter space is run.
        cancel: if truthy, stop processing results after the first failure.
            It is also forwarded to every ``worker`` process.

    Raises:
        SystemExit: with code 1 when no worker can be started.
    """
    logger = logging.getLogger(__name__)
    handler = logging.FileHandler('{name}.log'.format(name=name), delay=True)
    handler.setLevel(logging.ERROR)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    try:
        if gpu:
            # only GPUs with load < 0.2 are considered available
            worker_ids = [g.id for g in GPUtil.getGPUs() if g.load < 0.2]
            num_workers = min(workers, len(worker_ids))

            if num_workers <= 0:
                logger.error("no gpus available")
                sys.exit(1)
        else:
            num_workers = min(workers, mp.cpu_count())

            if num_workers <= 0:
                logger.error("--workers cannot be 0")
                sys.exit(1)

        ps = ParamSpace(filename=params)
        ps.write_configs('{}_params.csv'.format(name))

        param_grid = ps.param_grid(runs=runs)
        n_tasks = ps.size * runs

        if len(config_ids) > 0:
            wanted_ids = set(config_ids)
            selected = [p for p in param_grid if p["id"] in wanted_ids]
            # Count what was actually selected: deriving the task count from
            # len(config_ids) over-counts when an id is unknown or repeated,
            # and the scheduler would then exhaust the grid (StopIteration).
            n_tasks = len(selected)
            param_grid = iter(selected)

        num_workers = min(n_tasks, num_workers)

        print("----------Parameter Space Runner------------")
        print(":: tasks: {}".format(n_tasks))
        print(":: workers: {}".format(num_workers))
        print("--------------------------------------------")

        config_queue = Queue()
        result_queue = Queue()
        error_queue = Queue()
        progress_bar = tqdm(total=n_tasks, leave=True)

        terminate_flags = [Event() for _ in range(num_workers)]
        processes = [
            Process(target=worker,
                    args=(i, module, config_queue, result_queue, error_queue,
                          terminate_flags[i], cancel))
            for i in range(num_workers)
        ]

        scores = {}
        configs = {}

        def submit_next():
            """Take the next configuration from the grid and queue it."""
            next_cfg = next(param_grid)
            configs[next_cfg["id"]] = next_cfg
            config_queue.put(next_cfg)

        # submit one job per worker before starting the processes
        for _ in range(num_workers):
            submit_next()

        for p in processes:
            p.daemon = True
            p.start()

        num_completed = 0
        pending = num_workers
        successful = set()

        while num_completed < n_tasks:
            try:
                pid, cfg_id, result = result_queue.get(timeout=1)
            except QueueEmpty:
                # nothing finished within the timeout, poll again
                continue

            if isinstance(result, Exception):
                # retrieve one error from the queue; it might not be exactly
                # the one that failed since other workers can write to the
                # queue, but there is at least one error to retrieve
                _, cfg_id_err, err = error_queue.get()
                logger.error("configuration {} failed".format(cfg_id_err))
                logger.error(err)

                if cancel:
                    # stop at the first failure; workers are terminated below
                    break
            else:
                successful.add(cfg_id)
                scores[cfg_id] = result

            num_completed += 1
            pending -= 1

            if (num_completed + pending) != n_tasks:
                submit_next()
                pending += 1
            else:
                # no more work to be done: signal this worker to terminate
                terminate_flags[pid].set()

            progress_bar.update()

        # try to wait for process termination
        for process in processes:
            process.join(timeout=0.5)

            if process.is_alive():
                process.terminate()

        if len(config_ids) > 0:
            all_ids = set(config_ids)
        else:
            all_ids = set(range(ps.size))

        failed_tasks = all_ids.difference(successful)
        if len(failed_tasks) > 0:
            ids = " ".join(map(str, failed_tasks))
            fail_runs = "failed runs: {}".format(ids)
            print(fail_runs, file=sys.stderr)
            # Logger.warn is a deprecated alias that newer Python versions removed
            logger.warning(fail_runs)

        progress_bar.close()

    except TomlDecodeError as e:
        logger.error(traceback.format_exc())
        print("\n\n[Invalid parameter file] TOML decode error:\n {}".format(e), file=sys.stderr)
    except ParamDecodeError as e:
        logger.error(traceback.format_exc())
        print("\n\n[Invalid parameter file]\n {}".format(e), file=sys.stderr)