Esempio n. 1
0
    def test_add_random(self):
        """Check that non-persisted random parameters are re-sampled per grid.

        With ``persist=False`` the test expects every ``param_grid()`` call to
        draw new random values, so the first configuration taken from two
        separately requested grids must differ for the random parameter.
        (``persist=True`` would instead store the generated values as a list.)
        """
        key = "param1"
        space = ParamSpace()
        space.add_random(key,
                         low=2,
                         high=4,
                         persist=False,
                         n=10,
                         prior="uniform")

        first_grid = space.param_grid()
        # NOTE(review): the second argument of assertTrue is the failure
        # message, so this only checks that the size is truthy. Confirm the
        # intended expected size and switch to assertEqual.
        self.assertTrue(space.size, 1)
        first_sample = next(first_grid)[key]

        second_grid = space.param_grid()
        second_sample = next(second_grid)[key]

        # The space is also written to disk before the final comparison.
        space.write("test.cfg")

        self.assertNotEqual(first_sample, second_sample)
Esempio n. 2
0
    def test_param_grid(self):
        """Check that the grid has one entry per combination of values.

        One fixed value, a two-item list and three random samples are
        expected to produce ``1 * 2 * 3`` configurations, matching ``size``.
        """
        space = ParamSpace()
        space.add_value("p1", True)
        space.add_list("p2", ["A", "B"])
        space.add_random("p3", low=0, high=4, prior="uniform", n=3)

        # An unconsumed grid is requested first; the assertions below use a
        # second, freshly requested grid.
        space.param_grid()

        configurations = list(space.param_grid())
        expected_count = 1 * 2 * 3
        self.assertEqual(len(configurations), expected_count)
        self.assertEqual(len(configurations), space.size)
Esempio n. 3
0
    def test_param_grid_with_id(self):
        """Check that ``runs=5`` yields five entries per configuration."""
        space = ParamSpace()
        space.add_value("p1", True)
        space.add_list("p2", ["A", "B"])

        configurations = list(space.param_grid(runs=5))

        expected_count = 1 * 2 * 5
        self.assertEqual(len(configurations), expected_count)
Esempio n. 4
0
def main(params, module, runs, name, workers, gpu, config_ids, cancel):
    """Run the configurations of a parameter space on a pool of processes.

    The parameter file is loaded into a ``ParamSpace``, its configurations
    are written to ``<name>_params.csv`` and every configuration (repeated
    ``runs`` times) is sent to worker processes through a queue. A new
    configuration is submitted each time a worker reports a result, so at
    most one task per worker is in flight. Values returned by the workers
    are not kept; only success or failure per configuration id is tracked,
    and ids without a successful result are reported on stderr at the end.

    Args:
        params: parameter file name handed to ``ParamSpace(filename=...)``.
        module: passed unchanged to every worker (presumably identifies the
            code each worker runs; confirm against ``worker``).
        runs: number of repetitions requested from ``param_grid``.
        name: prefix used for the ``<name>.log`` and ``<name>_params.csv``
            files.
        workers: upper bound on the number of worker processes.
        gpu: if truthy, the worker count is additionally bounded by the
            number of GPUs with load below 0.2; otherwise by the CPU count.
        config_ids: optional collection of configuration ids; when non-empty
            only grid entries whose ``"id"`` is listed are run.
        cancel: if truthy, stop scheduling as soon as one task fails. The
            flag is also forwarded to every worker.

    TOML and parameter decoding errors are logged and reported on stderr
    instead of being raised. The process exits with status 1 when no worker
    can be started.
    """
    logger = logging.getLogger(__name__)
    handler = logging.FileHandler('{name}.log'.format(name=name), delay=True)
    handler.setLevel(logging.ERROR)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    try:
        if gpu:
            # only GPUs whose current load is below 0.2 count as available
            # NOTE(review): only the number of available GPUs is used here;
            # workers are identified by their index, not by these GPU ids.
            worker_ids = [g.id for g in GPUtil.getGPUs() if g.load < 0.2]
            num_workers = min(workers, len(worker_ids))

            if num_workers <= 0:
                logger.error("no gpus available")
                sys.exit(1)
        else:
            num_workers = min(workers, mp.cpu_count())
            if num_workers <= 0:
                logger.error("--workers cannot be 0")
                sys.exit(1)

        ps = ParamSpace(filename=params)
        ps.write_configs('{}_params.csv'.format(name))

        param_grid = ps.param_grid(runs=runs)
        n_tasks = ps.size * runs

        if config_ids:
            # Count the tasks that actually survive the filter: duplicated or
            # unknown ids in config_ids must not inflate n_tasks, otherwise
            # the grid is exhausted before n_tasks jobs have been submitted.
            wanted_ids = set(config_ids)
            selected = [p for p in param_grid if p["id"] in wanted_ids]
            n_tasks = len(selected)
            param_grid = iter(selected)

        # never start more workers than there are tasks
        num_workers = min(n_tasks, num_workers)

        print("----------Parameter Space Runner------------")
        print(":: tasks: {}".format(n_tasks))
        print(":: workers: {}".format(num_workers))
        print("--------------------------------------------")

        config_queue = Queue()
        result_queue = Queue()
        error_queue = Queue()
        progress_bar = tqdm(total=n_tasks, leave=True)

        terminate_flags = [Event() for _ in range(num_workers)]
        processes = [
            Process(target=worker,
                    args=(i, module, config_queue, result_queue, error_queue,
                          terminate_flags[i], cancel))
            for i in range(num_workers)
        ]

        # prime the queue with one job per worker
        for _ in range(num_workers):
            config_queue.put(next(param_grid))

        for p in processes:
            p.daemon = True
            p.start()

        num_completed = 0
        pending = num_workers
        successful = set()

        while num_completed < n_tasks:
            try:
                pid, cfg_id, result = result_queue.get(timeout=1)
            except QueueEmpty:
                # nothing finished within the timeout; keep polling
                continue

            if isinstance(result, Exception):
                # retrieve one error from queue, might not be exactly the one
                # that failed since other workers can write to the queue, but
                # there will be at least one error to retrieve
                _, cfg_id_err, err = error_queue.get()
                logger.error("configuration {} failed".format(cfg_id_err))
                logger.error(err)

                if cancel:
                    # stop scheduling; remaining workers are cleaned up below
                    break
            else:
                successful.add(cfg_id)

            num_completed += 1
            pending -= 1

            if (num_completed + pending) != n_tasks:
                # there is still unsubmitted work: hand out the next job
                config_queue.put(next(param_grid))
                pending += 1
            else:
                # signal the current worker for termination, no more work
                # to be done
                terminate_flags[pid].set()

            progress_bar.update()

        # try to wait for process termination, forcing it if necessary
        for process in processes:
            process.join(timeout=0.5)

            if process.is_alive():
                process.terminate()

        if config_ids:
            all_ids = set(config_ids)
        else:
            all_ids = set(range(ps.size))
        failed_tasks = all_ids.difference(successful)
        if failed_tasks:
            ids = " ".join(map(str, sorted(failed_tasks)))
            fail_runs = "failed runs: {}".format(ids)
            print(fail_runs, file=sys.stderr)
            # Logger.warn is a deprecated alias removed in Python 3.13
            logger.warning(fail_runs)

        progress_bar.close()

    except TomlDecodeError as e:
        logger.error(traceback.format_exc())
        print("\n\n[Invalid parameter file] TOML decode error:\n {}".format(e),
              file=sys.stderr)
    except ParamDecodeError as e:
        logger.error(traceback.format_exc())
        print("\n\n[Invalid parameter file]\n {}".format(e), file=sys.stderr)