def gradient(self, loss, return_grads=True, zero_grads=True):
    """Backprop the controller loss plus the entropy regularizer.

    Args:
        loss: the (differentiable) controller loss tensor.
        return_grads: if True, also return a snapshot of the parameter grads.
        zero_grads: if True, zero existing grads before the backward pass.

    Returns:
        The scalar loss (numpy), and — when ``return_grads`` —
        a list of ``(name, grad_clone)`` pairs for every named parameter.
    """
    if zero_grads:
        self.zero_grad()

    # Entropy regularization is added on top of the incoming loss.
    total_loss = loss + self._entropy_loss()
    total_loss.backward()

    loss_value = utils.get_numpy(total_loss)
    if not return_grads:
        return loss_value
    # Clone so later optimizer steps cannot mutate the returned grads.
    grad_snapshot = [
        (name, param.grad.clone()) for name, param in self.named_parameters()
    ]
    return loss_value, grad_snapshot
def summary(self, rollouts, log=False, log_prefix="", step=None):
    """Summarize a batch of rollouts: per-cell-group entropy and (when
    ``gumbel_hard``) the log-prob of the sampled discrete ops, averaged
    across rollouts.

    Fix: the entropy log format used ``%2f`` (min-width 2, full precision);
    changed to ``%.2f`` to match the ``-LOG_PROB: %.2f`` formatting beside it.

    Args:
        rollouts: rollouts carrying ``.arch`` and ``.logits``.
        log: whether to emit a logger.info summary line.
        log_prefix: prefix for the log line.
        step: if given (and a writer is configured), scalars are written
            to the summary writer at this step.

    Returns:
        OrderedDict of per-cell-group "<name> ENTRO" (and, with
        ``gumbel_hard``, "<name> LOGPROB") statistics.
    """
    num = len(rollouts)
    logits_list = [[utils.get_numpy(logits) for logits in r.logits]
                   for r in rollouts]
    _ss = self.search_space
    if self.gumbel_hard:
        cg_logprobs = [0. for _ in range(_ss.num_cell_groups)]
    cg_entros = [0. for _ in range(_ss.num_cell_groups)]
    for rollout, logits in zip(rollouts, logits_list):
        for cg_idx, (vec, cg_logits) in enumerate(zip(rollout.arch, logits)):
            prob = utils.softmax(cg_logits)
            logprob = np.log(prob)
            if self.gumbel_hard:
                # `vec.op_weights` holds the (straight-through) one-hot sample;
                # argmax recovers the chosen op index per edge.
                inds = np.argmax(utils.get_numpy(vec.op_weights), axis=-1)
                cg_logprobs[cg_idx] += np.sum(logprob[range(len(inds)), inds])
            cg_entros[cg_idx] += -(prob * logprob).sum()

    # mean across rollouts
    if self.gumbel_hard:
        cg_logprobs = [s / num for s in cg_logprobs]
        total_logprob = sum(cg_logprobs)
        cg_logprobs_str = ",".join(
            ["{:.2f}".format(n) for n in cg_logprobs])
    cg_entros = [s / num for s in cg_entros]
    total_entro = sum(cg_entros)
    cg_entro_str = ",".join(["{:.2f}".format(n) for n in cg_entros])

    if log:
        # maybe log the summary
        self.logger.info("%s%d rollouts: %s ENTROPY: %.2f (%s)",
                         log_prefix, num,
                         "-LOG_PROB: %.2f (%s) ;" % (
                             -total_logprob, cg_logprobs_str)
                         if self.gumbel_hard else "",
                         total_entro, cg_entro_str)
    if step is not None and not self.writer.is_none():
        if self.gumbel_hard:
            self.writer.add_scalar("log_prob", total_logprob, step)
        self.writer.add_scalar("entropy", total_entro, step)

    stats = [(n + " ENTRO", entro)
             for n, entro in zip(_ss.cell_group_names, cg_entros)]
    if self.gumbel_hard:
        stats += [(n + " LOGPROB", logprob)
                  for n, logprob in zip(_ss.cell_group_names, cg_logprobs)]
    return OrderedDict(stats)
def _eval_reward_func(self, data, cand_net, criterions, rollout, callback, kwargs):
    """Evaluate a candidate net on `data`, record the perfs on the rollout,
    and scalarize the "reward" perf.

    Returns the raw evaluation results from ``cand_net.eval_data``.
    """
    results = cand_net.eval_data(
        data, criterions=criterions, mode="train", **kwargs)
    # Attach every perf (keyed by self._all_perf_names) to the rollout.
    rollout.set_perfs(OrderedDict(zip(self._all_perf_names, results)))
    if callback is not None:
        callback(rollout)
    # Convert the "reward" perf into a plain scalar on the rollout.
    reward_value = utils.get_numpy(rollout.get_perf(name="reward"))
    rollout.set_perf(reward_value)
    return results
def gradient(self, loss, return_grads=True, zero_grads=True):
    """Backprop the controller loss plus the entropy regularizer,
    optionally inspecting the Hessian spectrum first.

    When ``self.inspect_hessian`` is set, the max Hessian eigenvalue of the
    loss w.r.t. every parameter is logged once, then the flag is cleared.

    Returns the scalar loss (numpy); with ``return_grads`` also a list of
    ``(name, grad_clone)`` pairs.
    """
    if zero_grads:
        self.zero_grad()

    if self.inspect_hessian:
        for name, param in self.named_parameters():
            max_eig = utils.torch_utils.max_eig_of_hessian(loss, param)
            self.logger.info("Max eigenvalue of Hessian of %s: %f",
                             name, max_eig)
        # One-shot inspection: do not repeat on subsequent calls.
        self.inspect_hessian = False

    total_loss = loss + self._entropy_loss()
    total_loss.backward()

    loss_value = utils.get_numpy(total_loss)
    if not return_grads:
        return loss_value
    # Clone grads so the snapshot survives later in-place updates.
    grad_snapshot = [
        (name, param.grad.clone()) for name, param in self.named_parameters()
    ]
    return loss_value, grad_snapshot
def summary(self, rollouts, log=False, log_prefix="", step=None):
    """Summarize rollouts that carry per-cell-group ``log_probs`` /
    ``entropies`` in ``r.info`` (averaged across rollouts), plus the mean
    reward when available.

    Fixes:
    - entropy log format ``%2f`` -> ``%.2f`` (consistent with the
      ``-LOG_PROB: %.2f`` formatting in the same message);
    - guard ``rewards[0]`` against an empty rollout list.

    Returns:
        OrderedDict of per-cell-group "<name> LOGPROB" / "<name> ENTRO"
        stats, plus "reward" when the rollouts carry one.
    """
    # log the total negative log prob and the entropies, averaged across
    # the rollouts; also the averaged info for each cell group
    cg_logprobs = np.mean(np.array([[utils.get_numpy(cg_lp).sum()
                                     for cg_lp in r.info["log_probs"]]
                                    for r in rollouts]), 0)
    total_logprob = cg_logprobs.sum()
    cg_logprobs_str = ",".join(["{:.2f}".format(n) for n in cg_logprobs])
    cg_entros = np.mean(np.array([[utils.get_numpy(cg_e).sum()
                                   for cg_e in r.info["entropies"]]
                                  for r in rollouts]), 0)
    total_entro = cg_entros.sum()
    cg_entro_str = ",".join(["{:.2f}".format(n) for n in cg_entros])
    num = len(rollouts)

    rewards = [r.get_perf("reward") for r in rollouts]
    # Robustness: avoid IndexError on an empty rollout list; a missing
    # reward on the first rollout means rewards are not reported.
    if rewards and rewards[0] is not None:
        total_reward = np.mean(rewards)
    else:
        total_reward = None

    if log:
        # maybe log the summary
        self.logger.info(
            "%s%d rollouts: -LOG_PROB: %.2f (%s) ; ENTROPY: %.2f (%s)",
            log_prefix, num,
            -total_logprob, cg_logprobs_str,
            total_entro, cg_entro_str)
        if step is not None and not self.writer.is_none():
            self.writer.add_scalar("log_prob", total_logprob, step)
            self.writer.add_scalar("entropy", total_entro, step)

    # return the stats
    _ss = self.search_space
    stats = [(n + " LOGPROB", logprob)
             for n, logprob in zip(_ss.cell_group_names, cg_logprobs)] + \
            [(n + " ENTRO", entro)
             for n, entro in zip(_ss.cell_group_names, cg_entros)]
    if total_reward is not None:
        stats += [("reward", total_reward)]
    return OrderedDict(stats)
def discretized_arch_and_prob(self):
    """Lazily parse (and cache) the discretized architecture and the edge
    probabilities from this rollout.

    Fix: the two adjacent string literals in the warning rendered as
    "...samplesto parse..." — a separating space is added.

    Returns:
        ``(discretized_arch, edge_probs)`` as produced by ``self.parse``.
    """
    if self._discretized_arch is None:
        if self.arch[0].ndimension() == 2:
            # Single sample per rollout: parse the samples directly.
            self._discretized_arch, self._edge_probs = self.parse(self.sampled)
        else:
            assert self.arch[0].ndimension() == 3
            self.logger.warning(
                "Rollout batch size > 1, use logits instead of samples "
                "to parse the discretized arch.")
            # if multiple arch samples per step is used, (2nd dim of
            # sampled/arch is batch_size dim). use softmax(logits) to parse
            # discretized arch
            self._discretized_arch, self._edge_probs = \
                self.parse(utils.softmax(utils.get_numpy(self.logits)))
    return self._discretized_arch, self._edge_probs
def sample(self, n=1, batch_size=1):
    """Sample ``n`` differentiable rollouts.

    For every cell-group alpha a relaxed sample is drawn (softmax of logits
    when ``use_prob``, otherwise gumbel-softmax); with ``gumbel_hard`` the
    arch uses the straight-through one-hot. ``batch_size > 1`` expands each
    alpha along a batch dimension before sampling.

    Returns:
        list of ``DiffRollout``.
    """
    out_rollouts = []
    for _ in range(n):
        archs, samples_np, logits_np = [], [], []
        for cg_alpha in self.cg_alphas:
            if self.force_uniform:
                # Zeros give a uniform distribution and detach the sample
                # from the cg_alpha parameters (they stay out of the graph).
                cg_alpha = torch.zeros_like(cg_alpha)

            if batch_size > 1:
                # [num_edges, num_ops] -> [num_edges * batch_size, num_ops]
                logits = cg_alpha.reshape(
                    [cg_alpha.shape[0], 1, cg_alpha.shape[1]])
                logits = logits.repeat([1, batch_size, 1]).reshape(
                    [-1, cg_alpha.shape[-1]])
            else:
                logits = cg_alpha

            if self.use_prob:
                # probability as the (soft) sample
                relaxed = F.softmax(logits / self.gumbel_temperature, dim=-1)
            else:
                # gumbel-softmax sampling
                relaxed, _ = utils.gumbel_softmax(
                    logits, self.gumbel_temperature, hard=False)

            arch = utils.straight_through(relaxed) if self.gumbel_hard \
                else relaxed

            if batch_size > 1:
                # restore the batch dim: [num_edges, batch_size, num_ops]
                relaxed = relaxed.reshape([-1, batch_size, arch.shape[-1]])
                arch = arch.reshape([-1, batch_size, arch.shape[-1]])

            archs.append(arch)
            samples_np.append(utils.get_numpy(relaxed))
            logits_np.append(utils.get_numpy(cg_alpha))
        out_rollouts.append(
            DiffRollout(archs, samples_np, logits_np, self.search_space))
    return out_rollouts
def test_diff_controller_use_prob():
    """With use_prob=True the sampled weights must equal softmax(logits)."""
    import numpy as np
    from aw_nas import utils
    from aw_nas.controller import DiffController

    ss = get_search_space(cls="cnn")
    ctrl = DiffController(ss, "cuda", use_prob=True)

    # 14 edges per cell group, one logit per shared primitive.
    expected_shape = (14, len(ss.shared_primitives))
    assert ctrl.cg_alphas[0].shape == expected_shape

    rollouts = ctrl.sample(3)
    first = rollouts[0]
    diff = utils.get_numpy(first.sampled[0]) - utils.softmax(first.logits[0])
    assert np.abs(diff).mean() < 1e-6
    assert isinstance(first.genotype, ss.genotype_type)
def discretized_arch_and_prob(self):
    """Lazily parse (and cache) the discretized arch and edge probabilities
    for rollouts whose arch entries are ``DartsArch(op_weights, edge_norms)``.

    When edge normalization is used, each cell group's sampled op weights
    are scaled by the (broadcast) edge norms before parsing.

    Fix: the two adjacent string literals in the warning rendered as
    "...samplesto parse..." — a separating space is added.

    Returns:
        ``(discretized_arch, edge_probs)`` as produced by ``self.parse``.
    """
    if self._discretized_arch is None:
        if self.arch[0].op_weights.ndimension() == 2:
            if self.arch[0].edge_norms is None:
                weights = self.sampled
            else:
                weights = []
                for cg_sampled, (_, cg_edge_norms) in zip(
                        self.sampled, self.arch):
                    # [num_edges] -> [num_edges, 1] to broadcast over ops
                    cg_edge_norms = utils.get_numpy(cg_edge_norms)[:, None]
                    weights.append(
                        utils.get_numpy(cg_sampled) * cg_edge_norms)
            self._discretized_arch, self._edge_probs = self.parse(weights)
        else:
            assert self.arch[0].op_weights.ndimension() == 3
            self.logger.warning(
                "Rollout batch size > 1, use logits instead of samples "
                "to parse the discretized arch.")
            # if multiple arch samples per step is used, (2nd dim of
            # sampled/arch is batch_size dim). use softmax(logits) to parse
            # discretized arch
            self._discretized_arch, self._edge_probs = \
                self.parse(utils.softmax(utils.get_numpy(self.logits)))
    return self._discretized_arch, self._edge_probs
def _init_criterions(self, rollout_type):
    """Set up the criterion callables used to evaluate/train on rollouts.

    Fix: ``"_reward_kwargs"`` was listed twice in
    ``_criterions_related_attrs``; the duplicate is removed.

    Args:
        rollout_type: must contain "differentiable" — only differentiable
            rollouts are handled specially here.
    """
    # criterion and forward keyword arguments for evaluating rollout in
    # `evaluate_rollout`; support compare rollout
    assert "differentiable" in rollout_type
    # NOTE: only handle differentiable rollout differently
    self._reward_func = partial(
        self.objective.get_loss,
        add_controller_regularization=True,
        add_evaluator_regularization=False,
    )
    self._reward_kwargs = {"detach_arch": False}
    self._scalar_reward_func = lambda *args, **kwargs: utils.get_numpy(
        self._reward_func(*args, **kwargs)
    )
    self._perf_names = self.objective.perf_names()
    self._all_perf_names = utils.flatten_list(
        ["reward", "loss", self._perf_names])
    # criterion funcs for meta parameter training
    self._eval_loss_func = partial(
        self.objective.get_loss,
        add_controller_regularization=False,
        add_evaluator_regularization=True,
    )
    # criterion funcs for log/report
    self._report_loss_funcs = [
        partial(
            self.objective.get_loss_item,
            add_controller_regularization=False,
            add_evaluator_regularization=False,
        ),
        self.objective.get_perfs,
    ]
    self._criterions_related_attrs = [
        "_reward_func",
        "_reward_kwargs",
        "_scalar_reward_func",
        "_perf_names",
        "_eval_loss_func",
        "_report_loss_funcs",
    ]
def sample(self, n=1, batch_size=1):
    """Sample ``n`` differentiable rollouts, optionally with edge
    normalization (beta weights softmaxed per destination node).

    Fix: the beta-slice offset used ``num_inputs_on_nodes[i_node - 1]``,
    which is only correct for the first two nodes. E.g. with 2 init nodes
    and 4 steps (num_inputs_on_nodes = [2, 3, 4, 5]) the slice starts were
    0, 2, 3, 4 — overlapping — instead of the contiguous partition
    0, 2, 5, 9 that the in-code example describes. The start is now the
    cumulative number of inputs of all preceding nodes.

    Returns:
        list of ``DiffRollout`` whose arch entries are
        ``DartsArch(op_weights, edge_norms)``.
    """
    rollouts = []
    for _ in range(n):
        # op_weights.shape: [num_edges, [batch_size,] num_ops]
        # edge_norms.shape: [num_edges] do not have batch_size.
        op_weights_list = []
        edge_norms_list = []
        sampled_list = []
        logits_list = []

        for alphas in self.cg_alphas:
            if self.force_uniform:
                # cg_alpha parameters will not be in the graph
                # NOTE: `force_uniform` config does not affect edge_norms
                # (betas); if one wants a force_uniform search, keep
                # `use_edge_normalization=False`
                alphas = torch.zeros_like(alphas)

            if batch_size > 1:
                expanded_alpha = alphas.reshape(
                    [alphas.shape[0], 1, alphas.shape[1]]) \
                    .repeat([1, batch_size, 1]) \
                    .reshape([-1, alphas.shape[-1]])
            else:
                expanded_alpha = alphas

            if self.use_prob:
                # probability as sample
                sampled = F.softmax(
                    expanded_alpha / self.gumbel_temperature, dim=-1)
            else:
                # gumbel sampling
                sampled, _ = utils.gumbel_softmax(
                    expanded_alpha, self.gumbel_temperature, hard=False)

            if self.gumbel_hard:
                op_weights = utils.straight_through(sampled)
            else:
                op_weights = sampled

            if batch_size > 1:
                sampled = sampled.reshape(
                    [-1, batch_size, op_weights.shape[-1]])
                op_weights = op_weights.reshape(
                    [-1, batch_size, op_weights.shape[-1]])

            op_weights_list.append(op_weights)
            sampled_list.append(utils.get_numpy(sampled))
            logits_list.append(utils.get_numpy(alphas))

        if self.use_edge_normalization:
            for i_cg, betas in enumerate(self.cg_betas):
                # eg: for 2 init_nodes and 3 steps, this is [2, 3, 4]
                num_inputs_on_nodes = np.arange(
                    self.search_space.get_num_steps(i_cg)) \
                    + self.search_space.num_init_nodes
                edge_norms = []
                for i_node, num_inputs_on_node in enumerate(
                        num_inputs_on_nodes):
                    # eg: node_0 has edge_{0, 1} as inputs: start=0, end=2;
                    # node_i starts after all inputs of the preceding nodes
                    start = int(num_inputs_on_nodes[:i_node].sum())
                    end = start + num_inputs_on_node
                    # softmax over the incoming edges of one node
                    edge_norms.append(F.softmax(betas[start:end], dim=0))
                edge_norms_list.append(torch.cat(edge_norms))
            arch_list = [
                DartsArch(op_weights=op_weights, edge_norms=edge_norms)
                for op_weights, edge_norms in zip(
                    op_weights_list, edge_norms_list)
            ]
        else:
            arch_list = [
                DartsArch(op_weights=op_weights, edge_norms=None)
                for op_weights in op_weights_list
            ]

        rollouts.append(
            DiffRollout(arch_list, sampled_list, logits_list,
                        self.search_space))
    return rollouts
def sample(load, out_file, n, save_plot, gpu, seed, dump_mode, prob_thresh, unique):
    """CLI entry: load a saved controller and sample `n` architectures.

    Samples are written to `out_file` (one header comment + dump per arch);
    optionally each arch is plotted to `save_plot`. With `prob_thresh`
    and/or `unique`, low-probability or duplicated archs are resampled.
    """
    LOGGER.info("CWD: %s", os.getcwd())
    LOGGER.info("CMD: %s", " ".join(sys.argv))
    setproctitle.setproctitle("awnas-sample load: {}; cwd: {}".format(
        load, os.getcwd()))

    # set gpu
    _set_gpu(gpu)
    device = torch.device(
        "cuda:{}".format(gpu) if torch.cuda.is_available() else "cpu")

    # set seed
    if seed is not None:
        LOGGER.info("Setting random seed: %d.", seed)
        np.random.seed(seed)
        random.seed(seed)
        torch.manual_seed(seed)

    # create the directory for saving plots
    if save_plot is not None:
        save_plot = utils.makedir(save_plot)

    controller_path = os.path.join(load)
    # load the model on cpu
    # NOTE(review): torch.load unpickles the whole controller object —
    # only load checkpoints from trusted sources.
    controller = torch.load(controller_path, map_location=torch.device("cpu"))
    # then set the device
    controller.set_device(device)

    if prob_thresh or unique:
        # Rejection-sampling loop: keep drawing until `n` archs pass the
        # probability threshold / uniqueness filters.
        # NOTE(review): if only `unique` is set and `prob_thresh` is None,
        # `np.exp(log_prob) < prob_thresh` compares float with None
        # (TypeError on py3) — confirm the CLI default for prob_thresh.
        # NOTE(review): can loop indefinitely if the controller cannot
        # produce enough qualifying unique archs.
        sampled = 0
        ignored = 0
        rollouts = []
        genotypes = []
        while sampled < n:
            rollout_cands = controller.sample(n - sampled)
            for r in rollout_cands:
                assert "log_probs" in r.info
                # total arch log-prob: sum over all cell groups
                log_prob = np.array([
                    utils.get_numpy(cg_lp) for cg_lp in r.info["log_probs"]
                ]).sum()
                if np.exp(log_prob) < prob_thresh:
                    ignored += 1
                    LOGGER.info("(ignored %d) Ignore arch prob %.3e (< %.3e)",
                                ignored, np.exp(log_prob), prob_thresh)
                elif r.genotype in genotypes:
                    ignored += 1
                    LOGGER.info("(ignored %d) Ignore duplicated arch",
                                ignored)
                else:
                    sampled += 1
                    LOGGER.info("(choosed %d) Choose arch prob %.3e (>= %.3e)",
                                sampled, np.exp(log_prob), prob_thresh)
                    rollouts.append(r)
                    genotypes.append(r.genotype)
    else:
        rollouts = controller.sample(n)

    # Write each sampled arch (and optionally its plot) out.
    with open(out_file, "w") as of:
        for i, r in enumerate(rollouts):
            if save_plot is not None:
                r.plot_arch(filename=os.path.join(save_plot, str(i)),
                            label="Derive {}".format(i))
            if "log_probs" in r.info:
                log_prob = np.array([
                    utils.get_numpy(cg_lp) for cg_lp in r.info["log_probs"]
                ]).sum()
                of.write("# ---- Arch {} log_prob: {:.3f} prob: {:.3e} ----\n".
                         format(i, log_prob, np.exp(log_prob)))
            else:
                of.write("# ---- Arch {} ----\n".format(i))
            _dump(r, dump_mode, of)
            of.write("\n")