def parse(self, weights):
    """Turn per-cell-group continuous weights into discretized archs.

    Each cell group's weight matrix is consumed step by step. The first step
    owns ``search_space.num_init_nodes`` consecutive rows, and every later step
    owns one row more than the previous one. Within a step, the rows (candidate
    input nodes) are ranked by their largest weight. The top
    ``search_space.num_node_inputs`` rows are kept, and for each kept row the
    argmax column is taken as its op.

    Args:
        weights: one weight matrix per cell group, zipped with ``self.logits``.

    Returns:
        ``(archs, edge_probs)``, where:

        * ``archs[i]`` is ``[from_nodes, op_indices]`` for cell group ``i``.
        * ``edge_probs[i]`` holds, for every chosen edge, the value of
          ``softmax(self.logits[i])`` at (edge row, chosen op), formatted
          as ``"{:.3f}"``.
    """
    all_archs, all_probs = [], []
    search_space = self.search_space
    for cg_idx, (cg_weight, cg_logits) in enumerate(zip(weights, self.logits)):
        cg_probs = softmax(cg_logits)
        from_nodes, op_ids, prob_strs = [], [], []
        offset = 0
        num_prev = search_space.num_init_nodes
        for _ in range(search_space.get_num_steps(cg_idx)):
            step_w = cg_weight[offset:offset + num_prev]
            step_p = cg_probs[offset:offset + num_prev]
            # Rank candidate input nodes by their strongest op weight (stable sort).
            ranked = sorted(range(num_prev),
                            key=lambda idx, rows=step_w: -max(rows[idx]))
            chosen = ranked[:search_space.num_node_inputs]
            chosen_ops = [np.argmax(step_w[idx]) for idx in chosen]
            from_nodes.extend(chosen)
            prob_strs.extend("{:.3f}".format(step_p[idx][op])
                             for idx, op in zip(chosen, chosen_ops))
            op_ids.extend(chosen_ops)
            # The next step has one more candidate input: the node just built.
            offset += num_prev
            num_prev += 1
        all_archs.append([from_nodes, op_ids])
        all_probs.append(prob_strs)
    return all_archs, all_probs
def discretized_arch_and_prob(self):
    """Return the discretized arch and its per-edge probabilities, computing them once.

    The result of ``self.parse`` is memoized in ``self._discretized_arch`` /
    ``self._edge_probs``; later calls return the cached pair.

    * If ``self.arch[0].ndimension()`` is 2, the sampled weights
      ``self.sampled`` are parsed.
    * Otherwise it must be 3 (the 2nd dim being a batch of samples). In that
      case ``softmax(logits)`` is parsed instead and a warning is logged.

    Returns:
        tuple: ``(discretized_arch, edge_probs)`` as produced by ``self.parse``.
    """
    if self._discretized_arch is None:
        if self.arch[0].ndimension() == 2:
            self._discretized_arch, self._edge_probs = self.parse(self.sampled)
        else:
            assert self.arch[0].ndimension() == 3
            self.logger.warning("Rollout batch size > 1, use logits instead of samples "
                                "to parse the discretized arch.")
            # if multiple arch samples per step is used, (2nd dim of sampled/arch is
            # batch_size dim). use softmax(logits) to parse discretized arch
            self._discretized_arch, self._edge_probs = \
                self.parse(utils.softmax(utils.get_numpy(self.logits)))
    return self._discretized_arch, self._edge_probs
def test_diff_controller_use_prob():
    """With ``use_prob=True`` the controller's sampled weights equal softmax(logits)."""
    from aw_nas import utils
    import numpy as np
    from aw_nas.controller import DiffController

    search_space = get_search_space(cls="cnn")
    controller = DiffController(search_space, "cuda", use_prob=True)
    # NOTE(review): 14 is presumably the edge count of the default cnn cell
    # (2 + 3 + 4 + 5 candidate inputs) -- confirm against the search space.
    num_ops = len(search_space.shared_primitives)
    assert controller.cg_alphas[0].shape == (14, num_ops)

    first = controller.sample(3)[0]
    sampled = utils.get_numpy(first.sampled[0])
    expected = utils.softmax(first.logits[0])
    assert np.abs(sampled - expected).mean() < 1e-6
    assert isinstance(first.genotype, search_space.genotype_type)
def summary(self, rollouts, log=False, log_prefix="", step=None):
    """Summarize the controller's architecture distribution over ``rollouts``.

    For every cell group this accumulates the entropy term
    ``-(p * log p).sum()``, where ``p = softmax(logits)``, and averages it over
    the rollouts. When ``self.gumbel_hard`` is set, it also averages the summed
    log-probability of the ops selected by ``argmax(op_weights, axis=-1)``.

    Args:
        rollouts: rollouts exposing ``logits`` (one entry per cell group) and
            ``arch`` (entries with an ``op_weights`` attribute).
        log: if True, emit a one-line summary through ``self.logger``.
        log_prefix: string prepended to the logged line.
        step: when given and ``self.writer`` is active, the totals are written
            as scalars at this step.

    Returns:
        OrderedDict mapping ``"<cell group name> ENTRO"`` (plus
        ``"<cell group name> LOGPROB"`` when ``gumbel_hard``) to the averages.
    """
    num = len(rollouts)
    logits_list = [[utils.get_numpy(logits) for logits in r.logits] for r in rollouts]
    _ss = self.search_space
    if self.gumbel_hard:
        cg_logprobs = [0. for _ in range(_ss.num_cell_groups)]
    cg_entros = [0. for _ in range(_ss.num_cell_groups)]
    for rollout, logits in zip(rollouts, logits_list):
        for cg_idx, (vec, cg_logits) in enumerate(zip(rollout.arch, logits)):
            prob = utils.softmax(cg_logits)
            logprob = np.log(prob)
            if self.gumbel_hard:
                # log-prob of the op actually selected on each row of op_weights
                inds = np.argmax(utils.get_numpy(vec.op_weights), axis=-1)
                cg_logprobs[cg_idx] += np.sum(logprob[range(len(inds)), inds])
            cg_entros[cg_idx] += -(prob * logprob).sum()

    # mean across rollouts
    if self.gumbel_hard:
        cg_logprobs = [s / num for s in cg_logprobs]
        total_logprob = sum(cg_logprobs)
        cg_logprobs_str = ",".join(["{:.2f}".format(n) for n in cg_logprobs])

    cg_entros = [s / num for s in cg_entros]
    total_entro = sum(cg_entros)
    cg_entro_str = ",".join(["{:.2f}".format(n) for n in cg_entros])

    if log:
        # maybe log the summary
        if self.gumbel_hard:
            logprob_part = "-LOG_PROB: %.2f (%s) ;" % (-total_logprob, cg_logprobs_str)
        else:
            logprob_part = ""
        # ``%.2f`` (2 decimals) to match the other values; ``%2f`` was a field width.
        self.logger.info("%s%d rollouts: %s ENTROPY: %.2f (%s)",
                         log_prefix, num, logprob_part, total_entro, cg_entro_str)
        if step is not None and not self.writer.is_none():
            if self.gumbel_hard:
                self.writer.add_scalar("log_prob", total_logprob, step)
            self.writer.add_scalar("entropy", total_entro, step)

    stats = [(n + " ENTRO", entro) for n, entro in zip(_ss.cell_group_names, cg_entros)]
    if self.gumbel_hard:
        stats += [(n + " LOGPROB", logprob)
                  for n, logprob in zip(_ss.cell_group_names, cg_logprobs)]
    return OrderedDict(stats)
def test_diff_rollout(tmp_path):
    """Build a DifferentiableRollout from Gumbel-softmax samples and plot its cells."""
    import torch
    from aw_nas.common import get_search_space, DifferentiableRollout
    from aw_nas.utils import softmax

    ss = get_search_space(cls="cnn")
    num_edges = sum(ss.num_init_nodes + i for i in range(ss.num_steps))
    logits = [np.random.randn(num_edges, len(ss.shared_primitives))
              for _ in range(ss.num_cell_groups)]
    eps = 1e-20

    def _gumbel_softmax(cg_logits):
        # Perturb the logits with Gumbel noise -log(-log(U)), then apply softmax.
        uniform = np.random.rand(*cg_logits.shape)
        gumbel = -np.log(-np.log(uniform + eps) + eps)
        return torch.Tensor(softmax(cg_logits + gumbel))

    sampled = arch = [_gumbel_softmax(cg_logits) for cg_logits in logits]
    rollout = DifferentiableRollout(arch, sampled, logits, search_space=ss)
    print("genotype: ", rollout.genotype)

    prefix = os.path.join(str(tmp_path), "cell")
    fnames = rollout.plot_arch(prefix, label="test plot")
    expected = [(cn, prefix + "-{}.pdf".format(cn)) for cn in ss.cell_group_names]
    assert fnames == expected
def discretized_arch_and_prob(self):
    """Return the discretized arch and its per-edge probabilities, computing them once.

    The result of ``self.parse`` is memoized in ``self._discretized_arch`` /
    ``self._edge_probs``; later calls return the cached pair.

    * If ``self.arch[0].op_weights.ndimension()`` is 2, the sampled weights
      are parsed. When the arch carries ``edge_norms``, each cell group's
      sampled weights are first scaled by them.
    * Otherwise it must be 3 (the 2nd dim being a batch of samples). In that
      case ``softmax(logits)`` is parsed instead and a warning is logged.

    Returns:
        tuple: ``(discretized_arch, edge_probs)`` as produced by ``self.parse``.
    """
    if self._discretized_arch is None:
        if self.arch[0].op_weights.ndimension() == 2:
            if self.arch[0].edge_norms is None:
                # NOTE(review): passed by reference; ``parse`` must not modify
                # it in place, since these are the rollout's own weights.
                weights = self.sampled
            else:
                weights = []
                for cg_sampled, (_, cg_edge_norms) in zip(self.sampled, self.arch):
                    # [:, None] turns the norms into a column so each row of the
                    # sampled weights is scaled by its own norm (assumes one
                    # norm per row -- TODO confirm).
                    cg_edge_norms = utils.get_numpy(cg_edge_norms)[:, None]
                    weights.append(utils.get_numpy(cg_sampled) * cg_edge_norms)
            self._discretized_arch, self._edge_probs = self.parse(weights)
        else:
            assert self.arch[0].op_weights.ndimension() == 3
            self.logger.warning(
                "Rollout batch size > 1, use logits instead of samples "
                "to parse the discretized arch.")
            # if multiple arch samples per step is used, (2nd dim of sampled/arch is
            # batch_size dim). use softmax(logits) to parse discretized arch
            self._discretized_arch, self._edge_probs = \
                self.parse(utils.softmax(utils.get_numpy(self.logits)))
    return self._discretized_arch, self._edge_probs
def parse(self, weights):
    """Parse continuous architecture weights into discretized archs.

    Each cell group's weight matrix is consumed step by step. The first step
    owns ``search_space.num_init_nodes`` consecutive rows, and every later step
    owns one row more. Within a step, the ``search_space.num_node_inputs`` rows
    (candidate input nodes) with the largest maximum weight are kept. For each
    kept row, the argmax column is chosen as its op.

    When ``search_space.derive_without_none_op`` is set and ``"none"`` is one
    of the cell group's primitives, that op is masked out before choosing.

    Args:
        weights: one weight matrix per cell group (numpy array or torch
            tensor), zipped with ``self.logits``. The inputs are never
            modified; masking happens on a private copy.

    Returns:
        ``(archs, edge_probs)``, where:

        * ``archs[i]`` is ``[from_nodes, op_indices]`` for cell group ``i``.
        * ``edge_probs[i]`` holds, for every chosen edge, the value of
          ``softmax(self.logits[i])`` at (edge row, chosen op), formatted
          as ``"{:.3f}"``.
    """
    archs = []
    edge_probs = []
    search_space = self.search_space
    for i_cg, (cg_weight, cg_logits) in enumerate(zip(weights, self.logits)):
        # Work on a private numpy copy. The "none" masking below writes into the
        # matrix, and callers pass weights they keep using (e.g. a rollout's
        # ``sampled`` tensors), which must not be altered. Torch tensors are
        # detached and moved to host first, because np.array cannot convert
        # grad-requiring or CUDA tensors.
        if hasattr(cg_weight, "detach"):
            cg_weight = cg_weight.detach().cpu().numpy()
        cg_weight = np.array(cg_weight, copy=True)

        if search_space.derive_without_none_op:
            try:
                none_op_idx = search_space.cell_shared_primitives[i_cg].index("none")
            except ValueError:  # "none" is not in primitives
                none_op_idx = None
            if none_op_idx is not None:
                # NOTE(review): -1 only ranks below every real op when the
                # weights are non-negative (e.g. probabilities) -- confirm.
                cg_weight[:, none_op_idx] = -1

        cg_probs = softmax(cg_logits)
        start = 0
        n = search_space.num_init_nodes
        arch = [[], []]
        edge_prob = []
        num_steps = search_space.get_num_steps(i_cg)
        for _ in range(num_steps):
            end = start + n
            w = cg_weight[start:end]
            probs = cg_probs[start:end]
            # Keep the input nodes whose strongest op weight is largest (stable sort).
            edges = sorted(range(n), key=lambda node_id, rows=w: -max(rows[node_id]))
            edges = edges[:search_space.num_node_inputs]
            op_lst = [np.argmax(w[edge]) for edge in edges]  # ops
            arch[0] += edges  # from nodes
            edge_prob += ["{:.3f}".format(probs[edge][op_id])
                          for edge, op_id in zip(edges, op_lst)]
            arch[1] += op_lst
            # The next step has one more candidate input: the node just built.
            n += 1
            start = end
        archs.append(arch)
        edge_probs.append(edge_prob)
    return archs, edge_probs