Example #1
 def load_state_dict(self, state_dict)->None:
     for c, cs in utils.zip_eq(self.cell_descs(), state_dict['cell_descs']):
         c.load_state_dict(cs)
     for stem, state in utils.zip_eq(self.model_stems, state_dict['model_stems']):
         stem.load_state_dict(state)
     self.pool_op.load_state_dict(state_dict['pool_op'])
     self.logits_op.load_state_dict(state_dict['logits_op'])
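Every example on this page funnels through `zip_eq`, whose source is not shown here. A minimal sketch of such a strict zip, assuming it raises ValueError on a length mismatch (the error type is an assumption), could look like:

    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar('T')
    U = TypeVar('U')

    _sentinel = object()

    def zip_eq(a: Iterable[T], b: Iterable[U]) -> Iterator[Tuple[T, U]]:
        """Like zip, but raises if the two iterables have unequal lengths."""
        ita, itb = iter(a), iter(b)
        while True:
            va = next(ita, _sentinel)
            vb = next(itb, _sentinel)
            if va is _sentinel and vb is _sentinel:
                return
            if va is _sentinel or vb is _sentinel:
                raise ValueError('zip_eq: iterables have unequal lengths')
            yield va, vb

On Python 3.10+ the builtin zip(a, b, strict=True) provides the same guarantee.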
Example #2
 def _flatten_ops_alphas(self):
     # Create a list of (alpha, input_id, op_desc) triples; the caller
     # sorts them and selects the top k.
     # Each op here is an nn.Sequential of a stop-gradient followed by the
     # primitive op.
     # The first loop walks edges with their alphas; the second walks the
     # ops within an edge with their alphas.
     return ((a, i, op[1])  # op[1] is the primitive; op[0] is the stop-grad
             for edge_alphas, i, edge in
                 zip_eq(self._alphas[0], range(self.desc.in_len), self._edges)
             for a, op in zip_eq(edge_alphas, edge))
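The comment mentions sorting the triples and selecting the top k; a hypothetical caller (the names `flat`, `k` and `top_k_ops` are assumptions, not from the source) could do that with heapq:

    import heapq

    def top_k_ops(flat, k):
        # alphas sort first in each (alpha, input_id, op) triple,
        # so nlargest keyed on the alpha picks the k strongest ops
        return heapq.nlargest(k, flat, key=lambda t: t[0])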
Example #3
    def forward(self, x:List[Tensor]):
        assert not isinstance(x, torch.Tensor)

        s = 0.0
        # apply each input in the list to its associated edge
        for i, (xi, edge) in enumerate(zip_eq(x, self._edges)):
            # apply the input to each primitive within the edge
            # TODO: is avg a better idea than sum here? sum can explode as
            #   the number of primitives goes up
            s = sum(a * op(xi) for a, op in zip_eq(self._alphas[0][i], edge)) + s
        return self._sf(s)
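In symbols, with alpha_ij the architecture weight of primitive o_ij on edge i, the loop computes the following, where sf is the final activation applied in the last line (its definition is not shown on this page):

    s = \sum_i \sum_j \alpha_{ij}\, o_{ij}(x_i), \qquad \text{out} = \mathrm{sf}(s)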
Example #4
    def load_state_dict(self, state_dict) -> None:
        assert self.id == state_dict['id']
        assert self.cell_type == state_dict['cell_type']

        for s, ss in utils.zip_eq(self.stems, state_dict['stems']):
            s.load_state_dict(ss)
        self.stem_shapes = state_dict['stem_shapes']

        for n, ns in utils.zip_eq(self.nodes(), state_dict['nodes']):
            n.load_state_dict(ns)
        self.node_shapes = state_dict['node_shapes']

        self.post_op.load_state_dict(state_dict['post_op'])
        self.out_shape = state_dict['out_shape']
Example #5
    def _hessian_vector_product(self, dw, x, y, epsilon_unit=1e-2):
        """
        Implements equation 8:

            dw      = dw' {L_val(w', alpha)}
            w+      = w + eps * dw
            w-      = w - eps * dw
            hessian = (dalpha {L_trn(w+, alpha)} - dalpha {L_trn(w-, alpha)}) / (2 * eps)
            eps     = 0.01 / ||dw||

        Epsilon is scaled with the gradient magnitude: dw is a multiplier
        on the RHS of eq 8, so this scaling is essential to keep the finite
        differences approximation from going way off. Below we flatten each
        w, concatenate them all and then take the norm.
        """
        # TODO: is cat along dim 0 correct?
        dw_norm = torch.cat([w.view(-1) for w in dw]).norm()
        epsilon = epsilon_unit / dw_norm

        # w+ = w + epsilon * grad(w')
        with torch.no_grad():
            for p, v in zip_eq(self._model.parameters(), dw):
                p += epsilon * v

        # Now that we have model with w+, we need to compute grads wrt alphas
        # This loss needs to be on train set, not validation set
        loss = _get_loss(self._model, self._lossfn, x, y)
        dalpha_plus = autograd.grad(loss, self._alphas)  # dalpha{L_trn(w+)}

        # get model with w- and then compute grads wrt alphas
        # w- = w - eps*dw`
        with torch.no_grad():
            for p, v in zip_eq(self._model.parameters(), dw):
                # we already added epsilon*dw above, so subtracting it twice gives w-
                p -= 2. * epsilon * v

        # similarly get dalpha_minus
        loss = _get_loss(self._model, self._lossfn, x, y)
        dalpha_minus = autograd.grad(loss, self._alphas)

        # reset params back to their original values by adding epsilon*dw
        with torch.no_grad():
            for p, v in zip_eq(self._model.parameters(), dw):
                p += epsilon * v

        # apply eq 8, final difference to compute hessian
        h = [(p - m) / (2. * epsilon)
             for p, m in zip_eq(dalpha_plus, dalpha_minus)]
        return h
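For reference, the finite-difference approximation this method implements (the docstring's equation 8, reconstructed; attributing it to the DARTS paper is an inference from the surrounding bilevel code):

    h = \nabla^2_{\alpha,w} L_{trn}(w,\alpha)\, dw
      \approx \frac{\nabla_\alpha L_{trn}(w^+,\alpha) - \nabla_\alpha L_{trn}(w^-,\alpha)}{2\epsilon},
    \quad w^\pm = w \pm \epsilon\, dw, \quad \epsilon = \frac{0.01}{\lVert dw \rVert}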
Example #6
 def ops(self) -> Iterator[Tuple['Op', float]]:  # type: ignore
     # pair each op with its alpha (NaN when alphas are not yet set) and
     # iterate in descending order of alpha
     alphas = (self._alphas[0] if self._alphas is not None
               else [math.nan] * len(self._ops))
     return iter(sorted(zip_eq(self._ops, alphas),
                        key=lambda t: t[1], reverse=True))
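A hypothetical usage of this iterator (the `edge` receiver is an assumption): since ops come sorted by descending alpha, picking the winning op is just taking the first item:

    best_op, best_alpha = next(edge.ops())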
Example #7
    def _update_vmodel(self, x, y, lr: float, w_optim: Optimizer) -> None:
        """ Update vmodel with w' (main model has w) """

        # TODO: should this loss be stored for later use?
        loss = _get_loss(self._model, self._lossfn, x, y)
        gradients = autograd.grad(loss, self._model_params())

        """update weights in vmodel so we leave main model undisturbed
        The main technical difficulty computing w' without affecting alphas is
        that you can't simply do backward() and step() on loss because loss
        tracks alphas as well as w. So, we compute gradients using autograd and
        do manual sgd update."""
        # TODO: other alternative may be to (1) copy model
        #   (2) set require_grads = False on alphas
        #   (3) loss and step on vmodel (4) set back require_grads = True
        with torch.no_grad():  # no need to track gradient for these operations
            for w, vw, g in zip(
                    self._model_params(), self._vmodel_params(), gradients):
                # simulate momentum update on model but put this update in vmodel
                m = w_optim.state[w].get(
                    'momentum_buffer', 0.)*self._w_momentum
                vw.copy_(w - lr * (m + g + self._w_weight_decay*w))

            # synchronize alphas
            for a, va in zip_eq(self._alphas, self._valphas):
                va.copy_(a)
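In symbols, the manual update mirrors SGD with momentum and weight decay, with eta = lr, mu = self._w_momentum, lambda = self._w_weight_decay and m the momentum buffer:

    w' = w - \eta \left( \mu\, m + \nabla_w L_{trn}(w,\alpha) + \lambda\, w \right)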
Example #8
    def _backward_bilevel(self, x_train, y_train, x_valid, y_valid, lr,
                          main_optim):
        """ Compute unrolled loss and backward its gradients """

        # update vmodel with w', but leave alphas as-is
        # w' = w - lr * grad
        unrolled_model = self._unrolled_model(x_train, y_train, lr, main_optim)

        # compute loss on the validation set for the model with w', then
        # backward() to get gradients wrt both alphas and weights, which
        # are read off the .grad fields below
        vloss = _get_loss(unrolled_model, self._lossfn, x_valid, y_valid)
        vloss.backward()
        dalpha = [v.grad for v in _get_alphas(unrolled_model)]
        dparams = [v.grad.data for v in unrolled_model.parameters()]

        hessian = self._hessian_vector_product(dparams, x_train, y_train)

        # the dalpha we have is from the unrolled model, so we need to
        # transfer those gradients back to our main model
        # final gradient = dalpha - xi*hessian (xi is lr here)
        # TODO: currently alphas lr is same as w lr
        with torch.no_grad():
            for alpha, da, h in zip_eq(self._alphas, dalpha, hessian):
                alpha.grad = da - lr * h
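Taken together, this is the DARTS-style approximation of the unrolled architecture gradient, with xi set equal to the weight learning rate lr, as the TODO notes:

    \nabla_\alpha L_{val}\bigl(w - \xi \nabla_w L_{trn}(w,\alpha),\, \alpha\bigr)
      \approx \nabla_\alpha L_{val}(w',\alpha) - \xi\, h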
Example #9
    def load_state_dict(self, state_dict: dict) -> None:
        optim_states = state_dict['optim_states']
        sched_states = state_dict['sched_states']

        for optim_sched, optim_state, sched_state in zip_eq(
                self, optim_states, sched_states):
            optim_sched.optim.load_state_dict(optim_state)
            if optim_sched.sched:
                assert sched_state is not None
                optim_sched.sched.load_state_dict(sched_state)
            else:
                assert sched_state is None
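A plausible counterpart that would produce this state_dict (an assumption; the real save side is not among the examples), iterating the same (optimizer, scheduler) pairs:

    def state_dict(self) -> dict:
        # one entry per (optim, sched) pair; None marks a missing scheduler
        return {
            'optim_states': [osched.optim.state_dict() for osched in self],
            'sched_states': [osched.sched.state_dict() if osched.sched else None
                             for osched in self],
        }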
Example #10
    def load_state_dict(self, state_dict) -> None:
        self.trainables = state_dict['trainables']
        c, cs = self.children, state_dict['children']
        assert (c is None and cs is None) or \
                (c is not None and cs is not None and len(c) == len(cs))

        # NOTE: when c and cs are both None, zip throws an error that the
        # first argument should be iterable, hence the early return
        if c is None and cs is None:
            return
        for cx, csx in utils.zip_eq(c, cs):
            if cx is not None and csx is not None:
                cx.load_state_dict(csx)
Example #11
    def forward(self, x):
        self._activs = [op(x) for op in self._ops]
        numer = sum(w * activ
                    for w, activ in zip_eq(self._alphas[0], self._activs))
        denom = sum(self._alphas[0])
        self.pt = torch.div(numer, denom)

        # register gradient hook on the first call only
        if self._is_first_call:
            self.pt.register_hook(self._save_grad())
            self._is_first_call = False

        return self.pt
Example #12
    def forward(self, x):
        self._activs = [op(x) for op in self._ops]
        numer = sum(w * activ
                    for w, activ in zip_eq(self._alphas[0], self._activs))
        denom = sum(self._alphas[0])
        self.pt = torch.div(numer, denom)

        # register hook to save gradients
        # NOTE: this has to be done on every forward call, since self.pt is
        # a new tensor each time; otherwise the hook doesn't remain
        # registered for subsequent loss.backward calls
        if self.training:
            self.pt.register_hook(self._save_grad())

        return self.pt
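A minimal standalone sketch of the gotcha the NOTE describes (not from the source): tensor hooks attach to one specific tensor, and since self.pt is recreated on each forward, a hook registered only once never fires again:

    import torch

    grads = []
    x = torch.ones(3, requires_grad=True)

    pt = (2 * x).sum()                            # like self.pt on call 1
    pt.register_hook(lambda g: grads.append(g))   # hook lives on this tensor
    pt.backward()

    pt2 = (2 * x).sum()                           # a fresh tensor on call 2
    pt2.backward()                                # old hook does not fire

    assert len(grads) == 1                        # hence re-register each call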
Example #13
File: gs_op.py  Project: wayne9qiu/archai
 def forward(self, x):
     assert self._sampled_weights is not None
     return sum(w * op(x)
                for w, op in zip_eq(self._sampled_weights, self._ops))
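The file name gs_op.py suggests the weights are Gumbel-Softmax samples; a hedged guess at how _sampled_weights might be produced (the function name and temperature are assumptions):

    import torch.nn.functional as F

    def sample_weights(alphas, tau: float = 1.0):
        # differentiable sample from the categorical defined by alphas
        return F.gumbel_softmax(alphas, tau=tau, hard=False, dim=-1)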
Example #14
File: gs_op.py  Project: wayne9qiu/archai
 def ops(self) -> Iterator[Tuple['Op', float]]:  # type: ignore
     return iter(
         sorted(zip_eq(self._ops, self._alphas[0]),
                key=lambda t: t[1],
                reverse=True))