Example #1
0
def test_derive_and_parse_derive():
    """Round-trip rollouts through `_dump_with_perf` and `_parse_derive_file`.

    Six random CNN rollouts are written to an in-memory buffer. Only the first
    four are given perf numbers, and the parser is expected to return exactly
    four entries.
    """
    import io

    import numpy as np

    from aw_nas.utils.common_utils import _parse_derive_file, _dump_with_perf
    from aw_nas.common import get_search_space

    search_space = get_search_space("cnn")
    samples = [search_space.random_sample() for _ in range(6)]

    # Attach performance information to the first four rollouts only.
    for sample in samples[:4]:
        sample.perf = {
            "reward": np.random.rand(),
            "other_perf": np.random.rand()
        }

    buffer = io.StringIO()
    # The first three are dumped with an explicit index, the rest without one.
    for idx in range(3):
        _dump_with_perf(samples[idx], "str", buffer, index=idx)
    for sample in samples[3:]:
        _dump_with_perf(sample, "str", buffer)

    parsed = _parse_derive_file(io.StringIO(buffer.getvalue()))
    assert len(parsed) == 4  # only 4 rollouts have performance information
    print(parsed)
Example #2
0
def random_sample(cfg_file, out_file, n, gpu, seed, dump_mode, unique):
    """Randomly sample ``n`` architectures from a search space and dump them.

    Args:
        cfg_file: path to a YAML config; its ``search_space`` section is used
            to build the search space via ``_init_component``.
        out_file: path of the file the sampled rollouts are written to
            (overwritten).
        n: number of rollouts to write.
        gpu: GPU id forwarded to ``_set_gpu``.
        seed: if not None, seeds ``numpy``, ``random`` and ``torch``.
        dump_mode: format forwarded to ``_dump_with_perf``.
        unique: if true, keep re-sampling until ``n`` rollouts with pairwise
            distinct genotypes (compared with ``==``) are collected.
    """
    LOGGER.info("CWD: %s", os.getcwd())
    LOGGER.info("CMD: %s", " ".join(sys.argv))

    setproctitle.setproctitle("awnas-random-sample cfg: {}; cwd: {}".format(
        cfg_file, os.getcwd()))

    # set gpu (no torch.device is built: nothing in this function uses one)
    _set_gpu(gpu)

    # set seed
    if seed is not None:
        LOGGER.info("Setting random seed: %d.", seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    with open(cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    ss = _init_component(cfg, "search_space")
    rollouts = []
    # Genotypes accepted so far; only consulted when ``unique`` is set. Kept as
    # a list (membership via ==) since genotype hashability is not established.
    seen_genotypes = []
    ignored = 0

    # NOTE(review): with ``unique`` set and a search space containing fewer
    # than ``n`` distinct genotypes, this loop never terminates.
    while len(rollouts) < n:
        r = ss.random_sample()
        if unique:
            if r.genotype in seen_genotypes:
                ignored += 1
                LOGGER.info("(ignored %d) Ignore duplicated arch", ignored)
                continue
            seen_genotypes.append(r.genotype)
            LOGGER.info("(choosed %d) Choose arch", len(rollouts) + 1)
        rollouts.append(r)

    with open(out_file, "w") as of:
        for i, r in enumerate(rollouts):
            _dump_with_perf(r, dump_mode, of, index=i)
Example #3
0
 def derive(self, n, steps=None, out_file=None):
     """Sample ``n`` rollouts from the controller in eval mode and evaluate them.

     Args:
         n: number of rollouts to sample.
         steps: forwarded to ``evaluate_rollouts`` as ``eval_batches``.
         out_file: optional path handed to ``_open_derive_out_file``. When that
             yields a non-None file object, each rollout is dumped to it (in
             "str" mode) as soon as its perf is known.

     Returns:
         The list of sampled rollouts. Genotypes already present in the
         recorded ``save_dict`` get their perf via ``set_perf``; the others are
         replaced by what ``evaluate_rollouts`` returns.
     """
     # Scheduled values (e.g. surrogate_lr, gumbel temperature) are also used at
     # test time; `self.on_epoch_start(self.epoch)` is already called in the
     # `load` method, so it is not repeated here.
     with self.controller.begin_mode("eval"):
         rollouts = self.controller.sample(n)
         # `save_dict` maps `str(genotype)` to a previously recorded perf;
         # presumably parsed from an existing out_file (verify in
         # `_open_derive_out_file`).
         with self._open_derive_out_file(out_file) as (out_f, save_dict):
             for i_sample, rollout in enumerate(rollouts):
                 key = str(rollout.genotype)
                 if key in save_dict:
                     # Reuse the recorded perf instead of re-evaluating. Set it
                     # *before* dumping so the dumped entry carries the perf.
                     rollout.set_perf(save_dict[key])
                 else:
                     rollout = self.evaluator.evaluate_rollouts(
                         [rollout], is_training=False, eval_batches=steps)[0]
                     # Keep the returned list in sync even if the evaluator
                     # hands back new rollout objects.
                     rollouts[i_sample] = rollout
                     print("Finish test {}/{}\r".format(i_sample + 1, n),
                           end="")
                 if out_f is not None:
                     _dump_with_perf(rollout, "str", out_f, index=i_sample)
     return rollouts
Example #4
0
def derive(cfg_file, load, out_file, n, save_plot, test, steps, gpu, seed,
           dump_mode, runtime_save):
    """Derive ``n`` architectures from a trained search and write them out.

    Without ``test``:
        The controller checkpoint at ``<load>/controller`` is loaded. ``n``
        rollouts are sampled in eval mode and dumped in sampling order.

    With ``test``:
        A full trainer is built and restored from ``load``. Its ``derive``
        samples and evaluates ``n`` rollouts. They are dumped best-first,
        ranked by ``get_perf()``. When ``runtime_save`` is set, ``out_file`` is
        also passed to the trainer so results are written as they are produced.
        The file is then rewritten with the ranked list at the end.

    Args:
        cfg_file: path to the YAML configuration.
        load: directory holding the saved checkpoint(s).
        out_file: path of the output file (overwritten).
        n: number of rollouts to derive.
        save_plot: optional directory; when given, each rollout's architecture
            is plotted there under its output index.
        test: whether to evaluate the derived rollouts with the trainer.
        steps: forwarded to ``trainer.derive`` when ``test`` is set.
        gpu: GPU id forwarded to ``_set_gpu``.
        seed: if not None, seeds ``numpy``, ``random`` and ``torch``.
        dump_mode: format forwarded to ``_dump_with_perf``.
        runtime_save: see ``test`` above.
    """
    LOGGER.info("CWD: %s", os.getcwd())
    LOGGER.info("CMD: %s", " ".join(sys.argv))
    title = "awnas-derive config: {}; load: {}; cwd: {}".format(
        cfg_file, load, os.getcwd())
    setproctitle.setproctitle(title)

    # Select the GPU; fall back to CPU when CUDA is unavailable.
    _set_gpu(gpu)
    if torch.cuda.is_available():
        device = torch.device("cuda:{}".format(gpu))
    else:
        device = torch.device("cpu")

    # Seed every RNG in use so the derivation is reproducible.
    if seed is not None:
        LOGGER.info("Setting random seed: %d.", seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    LOGGER.info("Loading configuration files.")
    with open(cfg_file, "r") as cfg_stream:
        cfg = yaml.safe_load(cfg_stream)

    LOGGER.info("Initializing components.")
    _, controller = _init_components_from_cfg(cfg, device,
                                              controller_only=True)

    if save_plot is not None:
        save_plot = utils.makedir(save_plot)

    if test:
        trainer = _init_components_from_cfg(cfg, device)[-1]

        LOGGER.info("Loading from disk...")
        trainer.setup(load=load)
        LOGGER.info("Deriving and testing...")
        derive_kwargs = {"out_file": out_file} if runtime_save else {}
        rollouts = trainer.derive(n, steps, **derive_kwargs)

        perfs = [rollout.get_perf() for rollout in rollouts]
        # argsort is ascending; reverse it to put the best `reward` first.
        ranking = np.argsort(perfs)[::-1]
        with open(out_file, "w") as out_stream:
            for rank, idx in enumerate(ranking):
                ranked = rollouts[idx]
                if save_plot is not None:
                    ranked.plot_arch(
                        filename=os.path.join(save_plot, str(rank)),
                        label="Derive {}; Reward {:.3f}".format(
                            rank, ranked.get_perf()))
                _dump_with_perf(ranked, dump_mode, out_stream, index=rank)
    else:
        controller.load(os.path.join(load, "controller"))
        controller.set_mode("eval")
        rollouts = controller.sample(n)
        with open(out_file, "w") as out_stream:
            for idx, rollout in enumerate(rollouts):
                if save_plot is not None:
                    rollout.plot_arch(
                        filename=os.path.join(save_plot, str(idx)),
                        label="Derive {}".format(idx))
                _dump_with_perf(rollout, dump_mode, out_stream, index=idx)