Example #1
 def gradient(self, loss, return_grads=True, zero_grads=True):
     if zero_grads:
         self.zero_grad()
     _loss = loss + self._entropy_loss()
     _loss.backward()
     if return_grads:
         return utils.get_numpy(_loss), [
             (k, v.grad.clone()) for k, v in self.named_parameters()
         ]
     return utils.get_numpy(_loss)
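The `gradient` method adds an entropy regularization term to the supplied loss, backpropagates, and optionally returns the scalar loss together with cloned per-parameter gradients. A minimal, self-contained PyTorch sketch of the same pattern (the toy parameter, entropy term, and coefficient below are illustrative, not the aw_nas implementation):

import torch
import torch.nn.functional as F

# toy "controller" with a single architecture parameter
alpha = torch.nn.Parameter(torch.randn(4, 5))
task_loss = (alpha ** 2).mean()                  # stand-in for the task loss

# entropy regularization, analogous to self._entropy_loss()
prob = F.softmax(alpha, dim=-1)
entropy_loss = (prob * torch.log(prob)).sum()    # negative entropy
coeff = 0.01                                     # illustrative coefficient

total = task_loss + coeff * entropy_loss
total.backward()

# cloned gradients keyed by name, as returned when return_grads=True
grads = [("alpha", alpha.grad.clone())]
print(total.item(), grads[0][1].shape)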
Example #2
    def summary(self, rollouts, log=False, log_prefix="", step=None):
        num = len(rollouts)
        logits_list = [[utils.get_numpy(logits) for logits in r.logits]
                       for r in rollouts]
        _ss = self.search_space
        if self.gumbel_hard:
            cg_logprobs = [0. for _ in range(_ss.num_cell_groups)]
        cg_entros = [0. for _ in range(_ss.num_cell_groups)]
        for rollout, logits in zip(rollouts, logits_list):
            for cg_idx, (vec,
                         cg_logits) in enumerate(zip(rollout.arch, logits)):
                prob = utils.softmax(cg_logits)
                logprob = np.log(prob)
                if self.gumbel_hard:
                    inds = np.argmax(utils.get_numpy(vec.op_weights), axis=-1)
                    cg_logprobs[cg_idx] += np.sum(logprob[range(len(inds)),
                                                          inds])
                cg_entros[cg_idx] += -(prob * logprob).sum()

        # mean across rollouts
        if self.gumbel_hard:
            cg_logprobs = [s / num for s in cg_logprobs]
            total_logprob = sum(cg_logprobs)
            cg_logprobs_str = ",".join(
                ["{:.2f}".format(n) for n in cg_logprobs])

        cg_entros = [s / num for s in cg_entros]
        total_entro = sum(cg_entros)
        cg_entro_str = ",".join(["{:.2f}".format(n) for n in cg_entros])

        if log:
            # maybe log the summary
            self.logger.info("%s%d rollouts: %s ENTROPY: %2f (%s)",
                             log_prefix, num,
                             "-LOG_PROB: %.2f (%s) ;" % (-total_logprob, cg_logprobs_str) \
                                 if self.gumbel_hard else "",
                             total_entro, cg_entro_str)
            if step is not None and not self.writer.is_none():
                if self.gumbel_hard:
                    self.writer.add_scalar("log_prob", total_logprob, step)
                self.writer.add_scalar("entropy", total_entro, step)

        stats = [(n + " ENTRO", entro)
                 for n, entro in zip(_ss.cell_group_names, cg_entros)]
        if self.gumbel_hard:
            stats += [(n + " LOGPROB", logprob) for n, logprob in \
                      zip(_ss.cell_group_names, cg_logprobs)]
        return OrderedDict(stats)
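The per-cell-group statistics above reduce to softmax arithmetic over the logits: the entropy of the op distribution on every edge, and (with `gumbel_hard`) the log-probability of the op that was actually chosen. A small NumPy sketch for a single cell group, with illustrative shapes:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

cg_logits = np.random.randn(14, 8)        # [num_edges, num_ops], illustrative sizes
prob = softmax(cg_logits)
logprob = np.log(prob)

# entropy summed over all edges of this cell group
entropy = -(prob * logprob).sum()

# with gumbel_hard: log-probability of the hard-selected op on each edge
inds = np.argmax(cg_logits, axis=-1)      # stand-in for argmax over op_weights
chosen_logprob = logprob[np.arange(len(inds)), inds].sum()
print(entropy, chosen_logprob)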
Example #3
 def _eval_reward_func(self, data, cand_net, criterions, rollout, callback, kwargs):
     res = cand_net.eval_data(data, criterions=criterions, mode="train", **kwargs)
     rollout.set_perfs(OrderedDict(zip(self._all_perf_names, res)))
     if callback is not None:
         callback(rollout)
     # set reward to be the scalar
     rollout.set_perf(utils.get_numpy(rollout.get_perf(name="reward")))
     return res
Example #4
    def gradient(self, loss, return_grads=True, zero_grads=True):
        if zero_grads:
            self.zero_grad()

        if self.inspect_hessian:
            for name, param in self.named_parameters():
                max_eig = utils.torch_utils.max_eig_of_hessian(loss, param)
                self.logger.info("Max eigenvalue of Hessian of %s: %f", name,
                                 max_eig)

            self.inspect_hessian = False

        _loss = loss + self._entropy_loss()
        _loss.backward()
        if return_grads:
            return utils.get_numpy(_loss), [
                (k, v.grad.clone()) for k, v in self.named_parameters()
            ]
        return utils.get_numpy(_loss)
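`utils.torch_utils.max_eig_of_hessian` itself is not shown here. One standard way to estimate the dominant Hessian eigenvalue without materializing the Hessian is power iteration on Hessian-vector products; the following is a hedged sketch of that idea, not necessarily what aw_nas implements:

import torch

def max_eig_of_hessian_sketch(loss, param, n_iters=20):
    # power iteration on Hessian-vector products; estimates the eigenvalue of
    # largest magnitude of the Hessian of `loss` w.r.t. `param`
    grad = torch.autograd.grad(loss, param, create_graph=True)[0]
    vec = torch.randn_like(param)
    vec = vec / vec.norm()
    eig = torch.tensor(0.0)
    for _ in range(n_iters):
        hvp = torch.autograd.grad(grad, param, grad_outputs=vec,
                                  retain_graph=True)[0]
        eig = torch.dot(hvp.flatten(), vec.flatten())   # Rayleigh quotient v^T H v
        vec = hvp / (hvp.norm() + 1e-12)
    return eig.item()

param = torch.nn.Parameter(torch.randn(3))
loss = (param ** 2).sum()      # Hessian is 2*I, so the estimate converges to ~2
print(max_eig_of_hessian_sketch(loss, param))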
Example #5
    def summary(self, rollouts, log=False, log_prefix="", step=None):
        # log the total negative log prob and the entropies, averaged across the rollouts
        # also the averaged info for each cell group
        cg_logprobs = np.mean(np.array([[utils.get_numpy(cg_lp).sum() \
                                         for cg_lp in r.info["log_probs"]]\
                                        for r in rollouts]), 0)
        total_logprob = cg_logprobs.sum()
        cg_logprobs_str = ",".join(["{:.2f}".format(n) for n in cg_logprobs])
        cg_entros = np.mean(np.array([[utils.get_numpy(cg_e).sum() \
                                         for cg_e in r.info["entropies"]]\
                                        for r in rollouts]), 0)
        total_entro = cg_entros.sum()
        cg_entro_str = ",".join(["{:.2f}".format(n) for n in cg_entros])
        num = len(rollouts)

        rewards = [r.get_perf("reward") for r in rollouts]
        if rewards[0] is not None:
            total_reward = np.mean(rewards)
        else:
            total_reward = None

        if log:
            # maybe log the summary
            self.logger.info(
                "%s%d rollouts: -LOG_PROB: %.2f (%s) ; ENTROPY: %2f (%s)",
                log_prefix, num, -total_logprob, cg_logprobs_str, total_entro,
                cg_entro_str)
            if step is not None and not self.writer.is_none():
                self.writer.add_scalar("log_prob", total_logprob, step)
                self.writer.add_scalar("entropy", total_entro, step)

        # return the stats
        _ss = self.search_space
        stats = [(n + " LOGPROB", logprob)
                 for n, logprob in zip(_ss.cell_group_names, cg_logprobs)] +\
                     [(n + " ENTRO", entro)
                      for n, entro in zip(_ss.cell_group_names, cg_entros)]
        if total_reward is not None:
            stats += [("reward", total_reward)]
        return OrderedDict(stats)
Example #6
 def discretized_arch_and_prob(self):
     if self._discretized_arch is None:
         if self.arch[0].ndimension() == 2:
             self._discretized_arch, self._edge_probs = self.parse(self.sampled)
         else:
             assert self.arch[0].ndimension() == 3
             self.logger.warning("Rollout batch size > 1, use logits instead of samples"
                                 "to parse the discretized arch.")
             # if multiple arch samples per step is used, (2nd dim of sampled/arch is
             # batch_size dim). use softmax(logits) to parse discretized arch
             self._discretized_arch, self._edge_probs = \
                                     self.parse(utils.softmax(utils.get_numpy(self.logits)))
     return self._discretized_arch, self._edge_probs
Example #7
 def sample(self, n=1, batch_size=1):
     rollouts = []
     for _ in range(n):
         arch_list = []
         sampled_list = []
         logits_list = []
         for alpha in self.cg_alphas:
             if self.force_uniform:  # cg_alpha parameters will not be in the graph
                 alpha = torch.zeros_like(alpha)
             if batch_size > 1:
                 expanded_alpha = alpha.reshape([alpha.shape[0], 1, alpha.shape[1]])\
                                       .repeat([1, batch_size, 1])\
                                       .reshape([-1, alpha.shape[-1]])
             else:
                 expanded_alpha = alpha
             if self.use_prob:
                 # probability as sample
                 sampled = F.softmax(expanded_alpha /
                                     self.gumbel_temperature,
                                     dim=-1)
             else:
                 # gumbel sampling
                 sampled, _ = utils.gumbel_softmax(expanded_alpha,
                                                   self.gumbel_temperature,
                                                   hard=False)
             if self.gumbel_hard:
                 arch = utils.straight_through(sampled)
             else:
                 arch = sampled
             if batch_size > 1:
                 sampled = sampled.reshape([-1, batch_size, arch.shape[-1]])
                 arch = arch.reshape([-1, batch_size, arch.shape[-1]])
             arch_list.append(arch)
             sampled_list.append(utils.get_numpy(sampled))
             logits_list.append(utils.get_numpy(alpha))
         rollouts.append(
             DiffRollout(arch_list, sampled_list, logits_list,
                         self.search_space))
     return rollouts
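The two helpers at the heart of the sampling step, `utils.gumbel_softmax` and `utils.straight_through`, are not shown. Below is a self-contained PyTorch sketch of the standard forms of both tricks (assumed implementations, not copied from aw_nas; note the library helper returns two values, the second of which is discarded in the snippet above):

import torch
import torch.nn.functional as F

def gumbel_softmax_sketch(logits, temperature):
    # perturb the logits with Gumbel(0, 1) noise, then relax with a softmax
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    return F.softmax((logits + gumbel) / temperature, dim=-1)

def straight_through_sketch(soft):
    # forward pass: one-hot argmax; backward pass: gradient of the soft sample
    index = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
    return (hard - soft).detach() + soft

alpha = torch.randn(14, 8, requires_grad=True)   # [num_edges, num_ops], illustrative
soft = gumbel_softmax_sketch(alpha, temperature=1.0)
hard = straight_through_sketch(soft)             # one-hot in the forward pass
hard.sum().backward()                            # gradients still flow into alpha
print(hard[0], alpha.grad.shape)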
Example #8
def test_diff_controller_use_prob():
    from aw_nas import utils
    import numpy as np
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    device = "cuda"
    controller = DiffController(search_space, device, use_prob=True)

    assert controller.cg_alphas[0].shape == (
        14, len(search_space.shared_primitives))
    rollouts = controller.sample(3)
    assert np.abs((utils.get_numpy(rollouts[0].sampled[0]) - utils.softmax(rollouts[0].logits[0])))\
             .mean() < 1e-6
    assert isinstance(rollouts[0].genotype, search_space.genotype_type)
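Note that the assertion on the sampled weights only holds because, with `use_prob=True`, the controller returns `softmax(alpha / gumbel_temperature)` directly and injects no Gumbel noise (see Example #7); the test implicitly relies on the default temperature leaving the softmax of the logits unchanged.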
Example #9
    def discretized_arch_and_prob(self):
        if self._discretized_arch is None:
            if self.arch[0].op_weights.ndimension() == 2:
                if self.arch[0].edge_norms is None:
                    weights = self.sampled
                else:
                    weights = []
                    for cg_sampled, (_, cg_edge_norms) in zip(
                            self.sampled, self.arch):
                        cg_edge_norms = utils.get_numpy(cg_edge_norms)[:, None]
                        weights.append(
                            utils.get_numpy(cg_sampled) * cg_edge_norms)

                self._discretized_arch, self._edge_probs = self.parse(weights)
            else:
                assert self.arch[0].op_weights.ndimension() == 3
                self.logger.warning(
                    "Rollout batch size > 1, use logits instead of samples"
                    "to parse the discretized arch.")
                # if multiple arch samples per step is used, (2nd dim of sampled/arch is
                # batch_size dim). use softmax(logits) to parse discretized arch
                self._discretized_arch, self._edge_probs = \
                                        self.parse(utils.softmax(utils.get_numpy(self.logits)))
        return self._discretized_arch, self._edge_probs
Example #10
    def _init_criterions(self, rollout_type):
        # criterion and forward keyword arguments for evaluating rollout in `evaluate_rollout`

        # support compare rollout
        assert "differentiable" in rollout_type

        # NOTE: only handle differentiable rollout differently
        self._reward_func = partial(
            self.objective.get_loss,
            add_controller_regularization=True,
            add_evaluator_regularization=False,
        )
        self._reward_kwargs = {"detach_arch": False}
        self._scalar_reward_func = lambda *args, **kwargs: utils.get_numpy(
            self._reward_func(*args, **kwargs)
        )

        self._perf_names = self.objective.perf_names()
        self._all_perf_names = utils.flatten_list(["reward", "loss", self._perf_names])
        # criterion funcs for meta parameter training
        self._eval_loss_func = partial(
            self.objective.get_loss,
            add_controller_regularization=False,
            add_evaluator_regularization=True,
        )
        # criterion funcs for log/report
        self._report_loss_funcs = [
            partial(
                self.objective.get_loss_item,
                add_controller_regularization=False,
                add_evaluator_regularization=False,
            ),
            self.objective.get_perfs,
        ]
        self._criterions_related_attrs = [
            "_reward_func",
            "_reward_kwargs",
            "_scalar_reward_func",
            "_reward_kwargs",
            "_perf_names",
            "_eval_loss_func",
            "_report_loss_funcs",
        ]
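The criterion setup consists almost entirely of `functools.partial` calls that bind different regularization flags of the same objective methods. A tiny self-contained sketch of that pattern (the `get_loss` stub is hypothetical, not the aw_nas objective API):

from functools import partial

def get_loss(inputs, targets, add_controller_regularization=False,
             add_evaluator_regularization=False):
    # hypothetical stand-in for objective.get_loss
    return {"controller_reg": add_controller_regularization,
            "evaluator_reg": add_evaluator_regularization}

# the reward criterion adds only controller regularization,
# the evaluator loss adds only evaluator regularization
reward_func = partial(get_loss, add_controller_regularization=True,
                      add_evaluator_regularization=False)
eval_loss_func = partial(get_loss, add_controller_regularization=False,
                         add_evaluator_regularization=True)

print(reward_func(None, None))     # {'controller_reg': True, 'evaluator_reg': False}
print(eval_loss_func(None, None))  # {'controller_reg': False, 'evaluator_reg': True}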
Example #11
    def sample(self, n=1, batch_size=1):
        rollouts = []
        for _ in range(n):
            # op_weights.shape: [num_edges, [batch_size,] num_ops]
            # edge_norms.shape: [num_edges] do not have batch_size.
            op_weights_list = []
            edge_norms_list = []
            sampled_list = []
            logits_list = []

            for alphas in self.cg_alphas:
                if self.force_uniform:  # cg_alpha parameters will not be in the graph
                    # NOTE: `force_uniform` config does not affects edge_norms (betas),
                    # if one wants a force_uniform search, keep `use_edge_normalization=False`
                    alphas = torch.zeros_like(alphas)

                if batch_size > 1:
                    expanded_alpha = alphas.reshape([alphas.shape[0], 1, alphas.shape[1]]) \
                        .repeat([1, batch_size, 1]) \
                        .reshape([-1, alphas.shape[-1]])
                else:
                    expanded_alpha = alphas

                if self.use_prob:
                    # probability as sample
                    sampled = F.softmax(expanded_alpha /
                                        self.gumbel_temperature,
                                        dim=-1)
                else:
                    # gumbel sampling
                    sampled, _ = utils.gumbel_softmax(expanded_alpha,
                                                      self.gumbel_temperature,
                                                      hard=False)

                if self.gumbel_hard:
                    op_weights = utils.straight_through(sampled)
                else:
                    op_weights = sampled

                if batch_size > 1:
                    sampled = sampled.reshape(
                        [-1, batch_size, op_weights.shape[-1]])
                    op_weights = op_weights.reshape(
                        [-1, batch_size, op_weights.shape[-1]])

                op_weights_list.append(op_weights)
                sampled_list.append(utils.get_numpy(sampled))
                logits_list.append(utils.get_numpy(alphas))

            if self.use_edge_normalization:
                for i_cg, betas in enumerate(self.cg_betas):
                    # eg: for 2 init_nodes and 3 steps, this is [2, 3, 4]
                    num_inputs_on_nodes = np.arange(self.search_space.get_num_steps(i_cg)) \
                                          + self.search_space.num_init_nodes
                    edge_norms = []
                    for i_node, num_inputs_on_node in enumerate(
                            num_inputs_on_nodes):
                        # eg: for node_0, it has edge_{0, 1} as inputs, therefore start=0, end=2
                        # start is the cumulative number of input edges of all previous nodes
                        start = int(np.sum(num_inputs_on_nodes[:i_node]))
                        end = start + num_inputs_on_node

                        edge_norms.append(F.softmax(betas[start:end], dim=0))

                    edge_norms_list.append(torch.cat(edge_norms))

                arch_list = [
                    DartsArch(op_weights=op_weights, edge_norms=edge_norms)
                    for op_weights, edge_norms in zip(op_weights_list,
                                                      edge_norms_list)
                ]
            else:
                arch_list = [
                    DartsArch(op_weights=op_weights, edge_norms=None)
                    for op_weights in op_weights_list
                ]

            rollouts.append(
                DiffRollout(arch_list, sampled_list, logits_list,
                            self.search_space))
        return rollouts
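The edge-normalization branch softmaxes the betas of the edges entering each node separately and then concatenates the segments back into a single per-edge vector. A NumPy sketch of that bookkeeping for 2 init nodes and 3 steps (the sizes follow the comment in the code; this is an illustration, not the library helper):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

num_init_nodes, num_steps = 2, 3
num_inputs_on_nodes = np.arange(num_steps) + num_init_nodes   # [2, 3, 4]
betas = np.random.randn(num_inputs_on_nodes.sum())            # one beta per input edge

edge_norms, start = [], 0
for num_inputs in num_inputs_on_nodes:
    # normalize the betas of the edges entering this node
    edge_norms.append(softmax(betas[start:start + num_inputs]))
    start += num_inputs
edge_norms = np.concatenate(edge_norms)    # same length and ordering as betas
print(edge_norms.shape)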
Example #12
def sample(load, out_file, n, save_plot, gpu, seed, dump_mode, prob_thresh,
           unique):
    LOGGER.info("CWD: %s", os.getcwd())
    LOGGER.info("CMD: %s", " ".join(sys.argv))

    setproctitle.setproctitle("awnas-sample load: {}; cwd: {}".format(
        load, os.getcwd()))

    # set gpu
    _set_gpu(gpu)
    device = torch.device(
        "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu")

    # set seed
    if seed is not None:
        LOGGER.info("Setting random seed: %d.", seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    # create the directory for saving plots
    if save_plot is not None:
        save_plot = utils.makedir(save_plot)

    controller_path = os.path.join(load)
    # load the model on cpu
    controller = torch.load(controller_path, map_location=torch.device("cpu"))
    # then set the device
    controller.set_device(device)

    if prob_thresh or unique:
        sampled = 0
        ignored = 0
        rollouts = []
        genotypes = []
        while sampled < n:
            rollout_cands = controller.sample(n - sampled)
            for r in rollout_cands:
                assert "log_probs" in r.info
                log_prob = np.array([
                    utils.get_numpy(cg_lp) for cg_lp in r.info["log_probs"]
                ]).sum()
                if np.exp(log_prob) < prob_thresh:
                    ignored += 1
                    LOGGER.info("(ignored %d) Ignore arch prob %.3e (< %.3e)",
                                ignored, np.exp(log_prob), prob_thresh)
                elif r.genotype in genotypes:
                    ignored += 1
                    LOGGER.info("(ignored %d) Ignore duplicated arch", ignored)
                else:
                    sampled += 1
                    LOGGER.info("(choosed %d) Choose arch prob %.3e (>= %.3e)",
                                sampled, np.exp(log_prob), prob_thresh)
                    rollouts.append(r)
                    genotypes.append(r.genotype)
    else:
        rollouts = controller.sample(n)

    with open(out_file, "w") as of:
        for i, r in enumerate(rollouts):
            if save_plot is not None:
                r.plot_arch(filename=os.path.join(save_plot, str(i)),
                            label="Derive {}".format(i))
            if "log_probs" in r.info:
                log_prob = np.array([
                    utils.get_numpy(cg_lp) for cg_lp in r.info["log_probs"]
                ]).sum()
                of.write("# ---- Arch {} log_prob: {:.3f} prob: {:.3e} ----\n".
                         format(i, log_prob, np.exp(log_prob)))
            else:
                of.write("# ---- Arch {} ----\n".format(i))
            _dump(r, dump_mode, of)
            of.write("\n")