def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
             affine: bool):
    super().__init__()

    # assume last PRIMITIVE is 'none'
    assert DivOp.PRIMITIVES[-1] == 'none'

    conf = get_conf()
    trainer = conf['nas']['search']['divnas']['archtrainer']
    finalizer = conf['nas']['search']['finalizer']

    if trainer == 'noalpha' and finalizer == 'default':
        raise NotImplementedError(
            'noalpha trainer is not implemented for the default finalizer')

    if trainer != 'noalpha':
        self._setup_arch_params(arch_params)
    else:
        self._alphas = None

    self._ops = nn.ModuleList()
    for primitive in DivOp.PRIMITIVES:
        op = Op.create(OpDesc(primitive, op_desc.params, in_len=1,
                              trainables=None),
                       affine=affine, arch_params=None)
        self._ops.append(op)

    # various state variables for diversity
    self._collect_activations = False
    self._forward_counter = 0
    self._batch_activs = None
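# The DivOp constructor above follows a standard mixed-op pattern: one
# module per primitive plus flags for caching activations. The sketch below
# is a hypothetical, self-contained illustration of that pattern (it is not
# DivOp's real forward; MixedOpSketch and its Linear ops are made up).
import torch
import torch.nn as nn

class MixedOpSketch(nn.Module):
    def __init__(self, dim: int, num_primitives: int = 3):
        super().__init__()
        # one candidate op per primitive
        self._ops = nn.ModuleList(nn.Linear(dim, dim)
                                  for _ in range(num_primitives))
        # state for diversity: optionally cache each op's output
        self._collect_activations = False
        self._batch_activs = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outs = [op(x) for op in self._ops]
        if self._collect_activations:
            self._batch_activs = [o.detach() for o in outs]
        return sum(outs)

# usage: MixedOpSketch(8)(torch.randn(4, 8)) has shape (4, 8)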
def train_test() -> Metrics:
    conf = common.get_conf()
    conf_eval = conf['nas']['eval']

    # region conf vars
    conf_loader = conf_eval['loader']
    conf_trainer = conf_eval['trainer']
    # endregion

    conf_trainer['validation']['freq'] = 1
    conf_trainer['epochs'] = 1
    conf_loader['train_batch'] = 128
    conf_loader['test_batch'] = 4096
    conf_loader['cutout'] = 0
    conf_trainer['drop_path_prob'] = 0.0
    conf_trainer['grad_clip'] = 0.0
    conf_trainer['aux_weight'] = 0.0

    Net = cifar10_models.resnet34
    model = Net().to(torch.device('cuda'))

    # get data
    data_loaders = data.get_data(conf_loader)
    assert data_loaders.train_dl is not None and data_loaders.test_dl is not None

    trainer = Trainer(conf_trainer, model, None)
    trainer.fit(data_loaders)
    met = trainer.get_metrics()
    return met
def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
    super().pre_fit(train_dl, val_dl)

    # optimizers and schedulers need to be recreated for each fit call
    # as they have state
    assert val_dl is not None

    conf = get_conf()
    self._train_batch = conf['nas']['search']['loader']['train_batch']
    # len(val_dl) is the number of batches, so this approximates the
    # number of validation examples
    num_val_examples = len(val_dl) * self._train_batch
    num_cells = conf['nas']['search']['model_desc']['n_cells']
    num_reduction_cells = conf['nas']['search']['model_desc']['n_reductions']
    num_normal_cells = num_cells - num_reduction_cells
    num_primitives = len(XnasOp.PRIMITIVES)

    assert num_cells > 0
    assert num_reduction_cells > 0
    assert num_normal_cells > 0
    assert num_primitives > 0

    self._normal_cell_effective_t = num_val_examples * self._epochs * num_normal_cells
    self._reduction_cell_effective_t = num_val_examples * self._epochs * num_reduction_cells

    self._normal_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (
        self._normal_cell_effective_t * self._grad_clip * self._grad_clip))
    self._reduction_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (
        self._reduction_cell_effective_t * self._grad_clip * self._grad_clip))

    self._xnas_optim = _XnasOptimizer(
        self._normal_cell_lr, self._reduction_cell_lr,
        self._normal_cell_effective_t, self._reduction_cell_effective_t,
        self._train_batch, self._grad_clip, self._multi_optim,
        self._apex, self.model)
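# The two learning rates above instantiate the exponentiated-gradient step
# size eta = sqrt(2*ln(K) / (T * L^2)), where K = number of primitives,
# T = effective number of prediction steps, L = the gradient-clip bound.
# A minimal numeric sketch (all values illustrative):
import math as ma

def xnas_lr(num_primitives: int, effective_t: int, grad_clip: float) -> float:
    # more steps or a looser clip bound -> smaller step size
    return ma.sqrt(2 * ma.log(num_primitives) /
                   (effective_t * grad_clip * grad_clip))

# e.g. 8 primitives, 10k effective steps, clip at 5.0 -> eta ~= 0.004
assert abs(xnas_lr(8, 10_000, 5.0) - 0.00408) < 1e-3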
def __init__(self, conf_train: Config, model: nn.Module,
             checkpoint: Optional[CheckPoint]) -> None:
    super().__init__(conf_train, model, checkpoint)

    conf = get_conf()
    self._gs_num_sample = conf['nas']['search']['model_desc']['cell']['gs']['num_sample']
def finalizers(self) -> Finalizers:
    conf = get_conf()
    finalizer = conf['nas']['search']['finalizer']

    if finalizer == 'mi':
        return DivnasFinalizers()
    else:
        return super().finalizers()
def pre_step(self, x: Tensor, y: Tensor) -> None:
    super().pre_step(x, y)

    # TODO: is it a good idea to ensure model is in training mode here?

    conf = get_conf()
    gs_num_sample = conf['nas']['search']['gs']['num_sample']

    # for each node in a cell, get the alphas of each incoming edge,
    # concatenate them all together, sample from them via GS and
    # push the resulting weights to the corresponding edge ops
    # for use in their respective forward
    for cell in self.model.cells:
        for node in cell.dag:
            # collect all alphas for all edges into the node
            node_alphas = []
            for edge in node:
                if isinstance(edge._op, GsOp):
                    node_alphas.extend(alpha for op, alpha in edge._op.ops())

            if not node_alphas:
                continue

            # stack (rather than torch.Tensor, which copies values) so
            # the alphas stay attached to the autograd graph
            node_alphas = torch.stack(node_alphas)

            # sample ops via gumbel softmax
            sample_storage = []
            for _ in range(gs_num_sample):
                sampled = F.gumbel_softmax(node_alphas, tau=1, hard=False,
                                           eps=1e-10, dim=-1)
                sample_storage.append(sampled)

            samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)
            samples = samples_summed / torch.sum(samples_summed)

            # TODO: should we be normalizing the sampled weights?
            # TODO: do gradients blow up as number of samples increases?

            # send the sampled op weights to their respective edges
            # to be used in forward
            counter = 0
            for edge in node:
                if isinstance(edge._op, GsOp):
                    this_edge_sampled_weights = \
                        samples[counter:counter + len(edge._op.PRIMITIVES)]
                    edge._op.set_op_sampled_weights(this_edge_sampled_weights)
                    counter += len(edge._op.PRIMITIVES)
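# Standalone sketch of the sampling scheme in pre_step: draw several soft
# Gumbel-softmax samples over one node's alphas and normalize their sum
# into op weights. The alphas and sample count below are made up.
import torch
import torch.nn.functional as F

alphas = torch.randn(6, requires_grad=True)      # hypothetical node alphas
samples = torch.stack([F.gumbel_softmax(alphas, tau=1, hard=False,
                                        eps=1e-10, dim=-1)
                       for _ in range(4)])       # gs_num_sample stand-in
weights = samples.sum(dim=0) / samples.sum()     # normalize to sum to 1
assert torch.isclose(weights.sum(), torch.tensor(1.0))
assert weights.requires_grad                     # graph is preserved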
def build(self, model_desc: ModelDesc, search_iter: int) -> None:
    # if this is not the first iteration, add a new node to each cell
    if search_iter > 0:
        self.add_node(model_desc)

    conf = get_conf()
    self._gs_num_sample = conf['nas']['search']['gs']['num_sample']

    for cell_desc in model_desc.cell_descs():
        self._build_cell(cell_desc, self._gs_num_sample)
def trainer_class(self) -> TArchTrainer:
    conf = get_conf()
    trainer = conf['nas']['search']['divnas']['archtrainer']

    if trainer == 'bilevel':
        return BilevelArchTrainer
    elif trainer == 'noalpha':
        return ArchTrainer
    else:
        raise NotImplementedError
def finalizers(self) -> Finalizers:
    conf = common.get_conf()
    finalizer = conf['nas']['search']['finalizer']

    if not finalizer or finalizer == 'default':
        return Finalizers()
    elif finalizer == 'random':
        return RandomFinalizers()
    else:
        raise NotImplementedError
def finalize_node(self, node: nn.ModuleList, node_index: int,
                  node_desc: NodeDesc, max_final_edges: int,
                  *args, **kwargs) -> NodeDesc:
    conf = get_conf()
    gs_num_sample = conf['nas']['search']['model_desc']['cell']['gs']['num_sample']

    # gather the alphas of all edges in this node
    node_alphas = []
    for edge in node:
        if isinstance(edge._op, GsOp):
            node_alphas.extend(alpha for op, alpha in edge._op.ops())

    assert len(node_alphas) > 0

    # stack (rather than torch.Tensor, which copies values) so the
    # alphas stay attached to the autograd graph
    node_alphas = torch.stack(node_alphas)

    # sample ops via gumbel softmax; hard=True makes each draw one-hot
    sample_storage = []
    for _ in range(gs_num_sample):
        sampled = F.gumbel_softmax(node_alphas, tau=1, hard=True,
                                   eps=1e-10, dim=-1)
        sample_storage.append(sampled)

    samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)

    # send the sampled op weights to their respective edges
    # to be used for edge-level finalize
    selected_edges = []
    counter = 0
    for edge in node:
        if isinstance(edge._op, GsOp):
            this_edge_sampled_weights = \
                samples_summed[counter:counter + len(edge._op.PRIMITIVES)]
            counter += len(edge._op.PRIMITIVES)

            # finalize the edge only if at least one of its ops was sampled
            if this_edge_sampled_weights.bool().any():
                op_desc, _ = edge._op.finalize(this_edge_sampled_weights)
                new_edge = EdgeDesc(op_desc, edge.input_ids)
                selected_edges.append(new_edge)

    # delete excess edges: since these are sampled edges there is no
    # ordering amongst them, so we just arbitrarily select a few
    if len(selected_edges) > max_final_edges:
        selected_edges = selected_edges[:max_final_edges]

    return NodeDesc(selected_edges, node_desc.conv_params)
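# Sketch of the selection test in finalize_node: with hard=True each draw
# is one-hot, so the summed samples count how often each op "won"; an edge
# survives if any of its ops received at least one vote. Sizes are made up.
import torch
import torch.nn.functional as F

alphas = torch.randn(6)                          # two edges x three ops
votes = torch.stack([F.gumbel_softmax(alphas, tau=1, hard=True,
                                      eps=1e-10, dim=-1)
                     for _ in range(4)]).sum(dim=0)
keep_first_edge = bool(votes[0:3].bool().any())  # slice for edge 0's ops
keep_second_edge = bool(votes[3:6].bool().any())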
def finalize_model(self, model: Model, to_cpu=True,
                   restore_device=True) -> ModelDesc:
    logger.pushd('finalize')

    # get config and train data loader
    # TODO: confirm this is correct in case you get silent bugs
    conf = get_conf()
    conf_loader = conf['nas']['search']['loader']
    train_dl, val_dl, test_dl = get_data(conf_loader)

    # wrap all cells in the model
    self._divnas_cells: Dict[int, Divnas_Cell] = {}
    for cell in model.cells:
        divnas_cell = Divnas_Cell(cell)
        self._divnas_cells[id(cell)] = divnas_cell

    # go through all edges in the DAG and if they are of DivOp
    # type then set them to collect activations
    sigma = conf['nas']['search']['divnas']['sigma']
    for dcell in self._divnas_cells.values():
        dcell.collect_activations(DivOp, sigma)

    # now run one evaluation epoch to collect activations; we do it on
    # cpu otherwise we might run into memory issues (later we can redo
    # the whole logic in pytorch itself). at the end of this each node
    # in a cell will have the covariance matrix of all incoming edges' ops
    model = model.cpu()
    model.eval()
    with torch.no_grad():
        for x, _ in train_dl:
            model(x)

    # now go through and update the node covariances in every cell
    for dcell in self._divnas_cells.values():
        dcell.update_covs()

    logger.popd()
    return super().finalize_model(model, to_cpu, restore_device)
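# The collection pass above relies on Divnas_Cell internals; the generic
# pattern it follows can be sketched with forward hooks (hypothetical toy
# model, not the Divnas_Cell API): record activations during a no-grad
# eval pass, then compute statistics such as covariances afterwards.
import torch
import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
activs = {}

def make_hook(name):
    def hook(module, inputs, output):
        activs.setdefault(name, []).append(output.detach())
    return hook

handles = [m.register_forward_hook(make_hook(n))
           for n, m in toy.named_modules() if isinstance(m, nn.Linear)]

toy.eval()
with torch.no_grad():
    for _ in range(2):                           # stand-in for train_dl
        toy(torch.randn(16, 8))

for h in handles:
    h.remove()
# activs now maps each Linear layer's name to its recorded batch outputs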
def update_alphas(self, eta: float, current_t: int, total_t: int,
                  grad_clip: float):
    grad_flat = torch.flatten(self._grad)
    rewards = torch.tensor([-torch.dot(grad_flat, torch.flatten(activ))
                            for activ in self._activs])
    exprewards = torch.exp(eta * rewards).cuda()
    # NOTE: will this remain registered?
    self._alphas[0] = torch.mul(self._alphas[0], exprewards)

    # weak learner eviction
    conf = get_conf()
    to_evict = conf['nas']['search']['xnas']['to_evict']
    if to_evict:
        theta = self._alphas[0].max() * ma.exp(-2 * eta * grad_clip *
                                               (total_t - current_t))
        assert len(self._ops) == self._alphas[0].shape[0]
        to_keep_mask = self._alphas[0] >= theta
        num_ops_kept = torch.sum(to_keep_mask).item()
        assert num_ops_kept > 0
        # zero out the weights which are evicted
        self._alphas[0] = torch.mul(self._alphas[0], to_keep_mask)

    # save some debugging info
    expdir = get_expdir()
    filename = os.path.join(expdir, str(id(self)) + '.txt')

    # save debug info to file
    alphas = [str(self._alphas[0][i].item())
              for i in range(self._alphas[0].shape[0])]
    with open(filename, 'a') as f:
        f.write(str(alphas))
        f.write('\n')
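# Self-contained sketch of the update above: alphas get a multiplicative
# exp(eta * reward) step, and when eviction is on, any alpha below the
# threshold theta = max_alpha * exp(-2*eta*L*(T - t)) is zeroed. All
# numbers here are illustrative.
import math as ma
import torch

eta, grad_clip, current_t, total_t = 0.01, 5.0, 10, 100
alphas = torch.ones(6)
rewards = torch.randn(6)                         # stand-in for -<grad, activ>
alphas = alphas * torch.exp(eta * rewards)       # exponentiated-gradient step

theta = alphas.max() * ma.exp(-2 * eta * grad_clip * (total_t - current_t))
alphas = alphas * (alphas >= theta)              # evict weak learners
assert int((alphas > 0).sum()) > 0               # the max always survives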