def load_state_dict(self, state_dict)->None:
    """Restore cell descriptors, model stems, pooling op and logits op.

    ``state_dict`` must provide the keys ``'cell_descs'``, ``'model_stems'``,
    ``'pool_op'`` and ``'logits_op'``. Cell-descriptor and stem states are
    paired positionally with ``self.cell_descs()`` / ``self.model_stems``
    through ``utils.zip_eq``.
    """
    cell_pairs = utils.zip_eq(self.cell_descs(), state_dict['cell_descs'])
    for cell_desc, cell_state in cell_pairs:
        cell_desc.load_state_dict(cell_state)

    stem_pairs = utils.zip_eq(self.model_stems, state_dict['model_stems'])
    for model_stem, stem_state in stem_pairs:
        model_stem.load_state_dict(stem_state)

    self.pool_op.load_state_dict(state_dict['pool_op'])
    self.logits_op.load_state_dict(state_dict['logits_op'])
def _flatten_ops_alphas(self):
    """Return a lazy generator of ``(alpha, input_id, op)`` triples.

    One triple is produced for every primitive on every input edge:
    ``input_id`` runs over ``range(self.desc.in_len)``, ``alpha`` is the
    matching entry of ``self._alphas[0][input_id]`` and ``op`` is element
    ``[1]`` of the edge entry. Per the original author's note, each edge entry
    is an ``nn.Sequential`` of stop-grad followed by the primitive op, so
    ``[1]`` selects the primitive. The original note says the caller sorts
    these triples and selects the top k.
    """
    # outer pairing: alphas of one edge, its input index, the edge itself
    edge_triples = zip_eq(self._alphas[0], range(self.desc.in_len), self._edges)
    return ((alpha, input_id, seq_op[1])
            for edge_alphas, input_id, edge in edge_triples
            # inner pairing: one alpha per primitive on this edge
            for alpha, seq_op in zip_eq(edge_alphas, edge))
def forward(self, x:List[Tensor]):
    """Feed each input through its edge, sum everything, then apply ``self._sf``.

    ``x`` must be a list of tensors (one per edge in ``self._edges``), not a
    single tensor. Edge ``i`` contributes ``sum_k alpha[i][k] * op_k(x[i])``
    using ``self._alphas[0][i]``.
    """
    assert not isinstance(x, torch.Tensor)
    total = 0.0
    for edge_idx, (x_in, edge) in enumerate(zip_eq(x, self._edges)):
        edge_alphas = self._alphas[0][edge_idx]
        # TODO: is avg better idea than sum here? sum can explode as
        # number of primitives goes up
        edge_out = sum(alpha * op(x_in) for alpha, op in zip_eq(edge_alphas, edge))
        total = edge_out + total
    return self._sf(total)
def load_state_dict(self, state_dict) -> None:
    """Restore stems, nodes, post-op and the cached shapes of this cell.

    The stored ``id`` and ``cell_type`` are asserted to match this cell before
    anything is loaded. Stem and node states are paired positionally with
    ``self.stems`` / ``self.nodes()`` via ``utils.zip_eq``.
    """
    assert self.id == state_dict['id']
    assert self.cell_type == state_dict['cell_type']

    def _restore_each(modules, module_states):
        # load each module from the state stored at the same position
        for module, module_state in utils.zip_eq(modules, module_states):
            module.load_state_dict(module_state)

    _restore_each(self.stems, state_dict['stems'])
    self.stem_shapes = state_dict['stem_shapes']

    _restore_each(self.nodes(), state_dict['nodes'])
    self.node_shapes = state_dict['node_shapes']

    self.post_op.load_state_dict(state_dict['post_op'])
    self.out_shape = state_dict['out_shape']
def _hessian_vector_product(self, dw, x, y, epsilon_unit=1e-2):
    """Finite-difference approximation of the Hessian-vector product (equation 8).

        dw = dw` {L_val(w`, alpha)}
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha {L_trn(w+, alpha)} - dalpha {L_trn(w-, alpha)}) / (2*eps)
        eps = epsilon_unit / ||dw||

    Epsilon is scaled by the gradient magnitude. dw is a multiplier on the
    RHS of eq. 8, so this scaling keeps the perturbation ``eps * dw`` at norm
    ``epsilon_unit`` and the finite-difference approximation from being way off.

    Args:
        dw: sequence of tensors, paired one-to-one (via ``zip_eq``) with
            ``self._model.parameters()``.
        x, y: training batch; the loss is evaluated through ``_get_loss``.
        epsilon_unit: norm of the parameter perturbation.

    Returns:
        List of tensors, one per entry of ``self._alphas``.

    The model parameters are shifted in place and shifted back at the end
    (restored up to floating-point rounding). If ``||dw||`` is zero, the
    product is exactly zero. Zeros are returned without touching the
    parameters, because ``epsilon`` would otherwise be infinite and
    ``inf * 0`` would write NaN into every parameter.
    """
    # Global L2 norm of dw: flatten every tensor and concatenate them into one
    # vector. Concatenating along dim 0 is correct because each piece is 1-D.
    # reshape (not view) also accepts non-contiguous tensors.
    dw_norm = torch.cat([w.reshape(-1) for w in dw]).norm()
    if dw_norm == 0:
        # H @ 0 == 0; avoid epsilon = inf corrupting the parameters with NaN
        return [torch.zeros_like(a) for a in self._alphas]
    epsilon = epsilon_unit / dw_norm

    # w+ = w + epsilon * dw
    with torch.no_grad():
        for p, v in zip_eq(self._model.parameters(), dw):
            p += epsilon * v

    # Now that we have the model with w+, compute grads wrt alphas.
    # This loss needs to be on the train set, not the validation set.
    loss = _get_loss(self._model, self._lossfn, x, y)
    dalpha_plus = autograd.grad(loss, self._alphas)  # dalpha{L_trn(w+)}

    # w- = w - epsilon * dw: epsilon*dw was already added above, so
    # subtracting twice that amount lands on w-
    with torch.no_grad():
        for p, v in zip_eq(self._model.parameters(), dw):
            p -= 2. * epsilon * v

    loss = _get_loss(self._model, self._lossfn, x, y)
    dalpha_minus = autograd.grad(loss, self._alphas)  # dalpha{L_trn(w-)}

    # shift the parameters back to w
    with torch.no_grad():
        for p, v in zip_eq(self._model.parameters(), dw):
            p += epsilon * v

    # apply eq. 8: central finite difference
    return [(p - m) / (2. * epsilon) for p, m in zip_eq(dalpha_plus, dalpha_minus)]
def ops(self) -> Iterator[Tuple['Op', float]]: # type: ignore
    """Iterate ``(op, alpha)`` pairs ordered from the largest alpha to the smallest.

    When ``self._alphas`` is None, every op is paired with ``math.nan``.
    NaN compares false against everything and the sort is stable, so that
    case yields the ops in their original order.
    """
    if self._alphas is None:
        weights = [math.nan] * len(self._ops)
    else:
        weights = self._alphas[0]
    ranked = sorted(zip_eq(self._ops, weights),
                    key=lambda pair: pair[1], reverse=True)
    return iter(ranked)
def _update_vmodel(self, x, y, lr: float, w_optim: Optimizer) -> None:
    """Update vmodel with w' (the main model keeps w).

    Writes ``w' = w - lr * (m + g + weight_decay * w)`` into the vmodel's
    parameters. Here ``g`` is the training-loss gradient and ``m`` is
    ``w_optim``'s stored momentum buffer times ``self._w_momentum``, which
    simulates a momentum update without stepping the main model. It then
    copies the main model's alphas into the vmodel's alphas.

    The main difficulty in computing w' without affecting alphas is that you
    can't simply call backward() and step() on the loss, because the loss also
    tracks alphas. So gradients are computed with autograd and a manual SGD
    update is applied.
    """
    # TODO: should this loss be stored for later use?
    loss = _get_loss(self._model, self._lossfn, x, y)
    gradients = autograd.grad(loss, self._model_params())

    # TODO: other alternative may be to (1) copy model
    # (2) set require_grads = False on alphas
    # (3) loss and step on vmodel (4) set back require_grads = True
    with torch.no_grad():  # no need to track gradient for these operations
        # zip_eq (the length-checked pairing used everywhere else in this
        # file, including the alpha sync below) instead of plain zip, so a
        # mismatch between model params, vmodel params and gradients is not
        # silently truncated
        for w, vw, g in zip_eq(
                self._model_params(), self._vmodel_params(), gradients):
            # simulate momentum update on model but put this update in vmodel;
            # falls back to 0. when no momentum buffer is stored for w
            m = w_optim.state[w].get(
                'momentum_buffer', 0.)*self._w_momentum
            vw.copy_(w - lr * (m + g + self._w_weight_decay*w))

        # synchronize alphas
        for a, va in zip_eq(self._alphas, self._valphas):
            va.copy_(a)
def _backward_bilevel(self, x_train, y_train, x_valid, y_valid, lr, main_optim): """ Compute unrolled loss and backward its gradients """ # update vmodel with w', but leave alphas as-is # w' = w - lr * grad unrolled_model = self._unrolled_model(x_train, y_train, lr, main_optim) # compute loss on validation set for model with w' # wrt alphas. The autograd.grad is used instead of backward() # to avoid having to loop through params vloss = _get_loss(unrolled_model, self._lossfn, x_valid, y_valid) vloss.backward() dalpha = [v.grad for v in _get_alphas(unrolled_model)] dparams = [v.grad.data for v in unrolled_model.parameters()] hessian = self._hessian_vector_product(dparams, x_train, y_train) # dalpha we have is from the unrolled model so we need to # transfer those grades back to our main model # update final gradient = dalpha - xi*hessian # TODO: currently alphas lr is same as w lr with torch.no_grad(): for alpha, da, h in zip_eq(self._alphas, dalpha, hessian): alpha.grad = da - lr * h
def load_state_dict(self, state_dict: dict) -> None:
    """Restore every optimizer and, where present, its scheduler.

    ``state_dict['optim_states']`` and ``state_dict['sched_states']`` are
    paired positionally with the items of ``self``. An item whose ``sched``
    is falsy must have a ``None`` scheduler state, and an item with a
    scheduler must have a non-``None`` one (both checked with assert).
    """
    triples = zip_eq(self, state_dict['optim_states'], state_dict['sched_states'])
    for item, optim_state, sched_state in triples:
        item.optim.load_state_dict(optim_state)
        sched = item.sched
        if not sched:
            assert sched_state is None
            continue
        assert sched_state is not None
        sched.load_state_dict(sched_state)
def load_state_dict(self, state_dict) -> None:
    """Restore ``trainables`` and recursively load each child's state.

    ``self.children`` and ``state_dict['children']`` must either both be None
    or both be non-None with equal length (checked with assert). Any pair
    where the child or its state is None is skipped.
    """
    self.trainables = state_dict['trainables']
    children, child_states = self.children, state_dict['children']
    both_none = children is None and child_states is None
    assert both_none or (children is not None and child_states is not None
                         and len(children) == len(child_states))
    if both_none:
        # nothing to pair up; zip_eq would reject None (non-iterable) arguments
        return
    for child, child_state in utils.zip_eq(children, child_states):
        if child is None or child_state is None:
            continue
        child.load_state_dict(child_state)
def forward(self, x):
    """Return the alpha-normalized weighted sum of all candidate ops applied to ``x``.

    Computes ``sum_i alpha_i * op_i(x) / sum_i alpha_i`` with
    ``self._alphas[0]``. The per-op activations are stored in
    ``self._activs`` and the result in ``self.pt``. A hook from
    ``self._save_grad()`` is attached to ``self.pt`` whenever it requires
    grad, so its gradient can be captured during backward.
    """
    self._activs = [op(x) for op in self._ops]
    numer = sum(w * activ for w, activ in zip_eq(self._alphas[0], self._activs))
    denom = sum(self._alphas[0])
    self.pt = torch.div(numer, denom)
    # Tensor hooks belong to one tensor object, and self.pt is a fresh tensor
    # on every call. Registering only on the first call therefore saved the
    # gradient of the first forward pass only. Register on every call that
    # can be backpropagated through; register_hook raises on tensors that
    # do not require grad (e.g. under torch.no_grad()).
    if self.pt.requires_grad:
        self.pt.register_hook(self._save_grad())
    # preserved in case other code reads this flag
    self._is_first_call = False
    return self.pt
def forward(self, x):
    """Return the alpha-normalized weighted sum of all candidate ops applied to ``x``.

    Computes ``sum_i alpha_i * op_i(x) / sum_i alpha_i`` with
    ``self._alphas[0]``. The per-op activations are stored in
    ``self._activs`` and the result in ``self.pt``. In training mode, a hook
    from ``self._save_grad()`` is attached to ``self.pt`` when it requires
    grad.
    """
    self._activs = [op(x) for op in self._ops]
    numer = sum(w * activ for w, activ in zip_eq(self._alphas[0], self._activs))
    denom = sum(self._alphas[0])
    self.pt = torch.div(numer, denom)
    # register hook to save gradients
    # NOTE: it has to be done every forward call
    # otherwise the hook doesn't remain registered
    # for subsequent loss.backward calls.
    # register_hook raises RuntimeError on a tensor that does not require
    # grad (e.g. a training-mode forward under torch.no_grad()), so skip it then.
    if self.training and self.pt.requires_grad:
        self.pt.register_hook(self._save_grad())
    return self.pt
def forward(self, x):
    """Sum every op applied to ``x``, each scaled by its sampled weight.

    ``self._sampled_weights`` must have been set (asserted non-None). It is
    paired positionally with ``self._ops``.
    """
    weights = self._sampled_weights
    assert weights is not None
    weighted = (weight * op(x) for weight, op in zip_eq(weights, self._ops))
    return sum(weighted)
def ops(self) -> Iterator[Tuple['Op', float]]: # type: ignore
    """Iterate ``(op, alpha)`` pairs from the highest alpha to the lowest.

    Alphas come from ``self._alphas[0]`` and are paired positionally with
    ``self._ops``. Ties keep their original order, since the sort is stable.
    """
    pairs = list(zip_eq(self._ops, self._alphas[0]))
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return iter(pairs)