Example #1
    def __init__(self, op_desc: OpDesc, arch_params: Optional[ArchParams],
                 affine: bool):
        super().__init__()

        # assume last PRIMITIVE is 'none'
        assert DivOp.PRIMITIVES[-1] == 'none'

        conf = get_conf()
        trainer = conf['nas']['search']['divnas']['archtrainer']
        finalizer = conf['nas']['search']['finalizer']

        if trainer == 'noalpha' and finalizer == 'default':
            raise NotImplementedError(
                'noalpha trainer is not implemented for the default finalizer')

        if trainer != 'noalpha':
            self._setup_arch_params(arch_params)
        else:
            self._alphas = None

        self._ops = nn.ModuleList()
        for primitive in DivOp.PRIMITIVES:
            op = Op.create(OpDesc(primitive,
                                  op_desc.params,
                                  in_len=1,
                                  trainables=None),
                           affine=affine,
                           arch_params=None)
            self._ops.append(op)

        # various state variables for diversity
        self._collect_activations = False
        self._forward_counter = 0
        self._batch_activs = None
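All of these examples share the same pattern: get_conf() returns the experiment's nested configuration, which is then indexed like a plain dictionary down to the value of interest. Below is a minimal standalone sketch of that pattern; the import path and the assumption that the experiment config has already been initialized elsewhere are not shown in the snippets and are only assumptions here.

from archai.common.common import get_conf   # assumed import path

# Assumes the experiment YAML has already been loaded by the experiment
# entry point; get_conf() then just returns that global config object.
conf = get_conf()

# Same nested lookups as Example #1; the keys mirror the search YAML layout.
trainer = conf['nas']['search']['divnas']['archtrainer']
finalizer = conf['nas']['search']['finalizer']
print(trainer, finalizer)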
Example #2
def train_test()->Metrics:
    conf = common.get_conf()
    conf_eval = conf['nas']['eval']

    # region conf vars
    conf_loader  = conf_eval['loader']
    conf_trainer = conf_eval['trainer']
    # endregion

    conf_trainer['validation']['freq'] = 1
    conf_trainer['epochs'] = 1
    conf_loader['train_batch'] = 128
    conf_loader['test_batch'] = 4096
    conf_loader['cutout'] = 0
    conf_trainer['drop_path_prob'] = 0.0
    conf_trainer['grad_clip'] = 0.0
    conf_trainer['aux_weight'] = 0.0

    Net = cifar10_models.resnet34
    model = Net().to(torch.device('cuda'))

    # get data
    data_loaders = data.get_data(conf_loader)
    assert data_loaders.train_dl is not None and data_loaders.test_dl is not None

    trainer = Trainer(conf_trainer, model, None)
    trainer.fit(data_loaders)
    met = trainer.get_metrics()
    return met
Example #3
    def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
        super().pre_fit(train_dl, val_dl)

        # optimizers and schedulers need to be recreated for each fit call
        # as they carry state
        assert val_dl is not None

        conf = get_conf()
        self._train_batch = conf['nas']['search']['loader']['train_batch']
        num_val_examples = len(val_dl) * self._train_batch
        num_cells = conf['nas']['search']['model_desc']['n_cells']
        num_reduction_cells = conf['nas']['search']['model_desc']['n_reductions']
        num_normal_cells = num_cells - num_reduction_cells
        num_primitives = len(XnasOp.PRIMITIVES)

        assert num_cells > 0
        assert num_reduction_cells > 0
        assert num_normal_cells > 0
        assert num_primitives > 0

        self._normal_cell_effective_t = num_val_examples * self._epochs * num_normal_cells
        self._reduction_cell_effective_t = num_val_examples * \
            self._epochs * num_reduction_cells

        self._normal_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (
            self._normal_cell_effective_t * self._grad_clip * self._grad_clip))
        self._reduction_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (
            self._reduction_cell_effective_t * self._grad_clip * self._grad_clip))

        self._xnas_optim = _XnasOptimizer(self._normal_cell_lr, self._reduction_cell_lr, self._normal_cell_effective_t,
                                          self._reduction_cell_effective_t, self._train_batch, self._grad_clip, 
                                          self._multi_optim, self._apex, self.model)
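The two learning rates above are the standard Hedge / exponentiated-gradient step size eta = sqrt(2 ln K / (T G^2)), with K the number of primitives, T the effective number of update steps for that cell type and G the gradient-clip bound. A quick numeric sketch follows; only the formula comes from the code, every concrete value below is hypothetical.

import math as ma

# Hypothetical values; in pre_fit they come from the loader config, the
# model description and XnasOp.PRIMITIVES.
num_primitives = 8            # K
num_val_examples = 25000      # len(val_dl) * train_batch
epochs = 50
num_normal_cells = 6
grad_clip = 5.0               # G

effective_t = num_val_examples * epochs * num_normal_cells   # T
lr = ma.sqrt(2 * ma.log(num_primitives) / (effective_t * grad_clip * grad_clip))
print(f'normal-cell arch lr ~= {lr:.3e}')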
Example #4
    def __init__(self, conf_train: Config, model: nn.Module,
                 checkpoint: Optional[CheckPoint]) -> None:
        super().__init__(conf_train, model, checkpoint)

        conf = get_conf()
        self._gs_num_sample = conf['nas']['search']['model_desc']['cell'][
            'gs']['num_sample']
Example #5
    def finalizers(self) -> Finalizers:
        conf = get_conf()
        finalizer = conf['nas']['search']['finalizer']

        if finalizer == 'mi':
            return DivnasFinalizers()
        else:
            return super().finalizers()
Example #6
    def pre_step(self, x: Tensor, y: Tensor) -> None:
        super().pre_step(x, y)

        # TODO: is it a good idea to ensure model is in training mode here?

        conf = get_conf()
        gs_num_sample = conf['nas']['search']['gs']['num_sample']

        # For each node in a cell, get the alphas of each incoming edge,
        # concatenate them all together, sample from them via Gumbel-Softmax,
        # and push the resulting weights to the corresponding edge ops for
        # use in their respective forward passes.

        for _, cell in enumerate(self.model.cells):
            for _, node in enumerate(cell.dag):
                # collect all alphas for all edges in to node
                node_alphas = []
                for edge in node:
                    if hasattr(edge._op, 'PRIMITIVES') and type(
                            edge._op) == GsOp:
                        node_alphas.extend(alpha
                                           for op, alpha in edge._op.ops())

                # TODO: will creating a tensor from a list of tensors preserve the graph?
                node_alphas = torch.Tensor(node_alphas)

                if node_alphas.nelement() > 0:
                    # sample ops via gumbel softmax
                    sample_storage = []
                    for _ in range(gs_num_sample):
                        sampled = F.gumbel_softmax(node_alphas,
                                                   tau=1,
                                                   hard=False,
                                                   eps=1e-10,
                                                   dim=-1)
                        sample_storage.append(sampled)

                    samples_summed = torch.sum(torch.stack(sample_storage,
                                                           dim=0),
                                               dim=0)
                    samples = samples_summed / torch.sum(samples_summed)

                    # TODO: should we be normalizing the sampled weights?
                    # TODO: do gradients blow up as number of samples increases?

                    # send the sampled op weights to their respective edges
                    # to be used in forward
                    counter = 0
                    for _, edge in enumerate(node):
                        if hasattr(edge._op, 'PRIMITIVES') and type(
                                edge._op) == GsOp:
                            this_edge_sampled_weights = samples[
                                counter:counter + len(edge._op.PRIMITIVES)]
                            edge._op.set_op_sampled_weights(
                                this_edge_sampled_weights)
                            counter += len(edge._op.PRIMITIVES)
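The sampling step above is easier to see in isolation. The sketch below reproduces it for a made-up node with two edges of four primitives each; only the call to F.gumbel_softmax comes from the code above, the shapes and counts are assumptions for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical node: two edges, four primitives per edge, eight alphas total.
node_alphas = torch.randn(8)
gs_num_sample = 4
primitives_per_edge = 4

# Draw several soft Gumbel-Softmax samples and sum them, as in pre_step.
stacked = torch.stack([
    F.gumbel_softmax(node_alphas, tau=1, hard=False, eps=1e-10, dim=-1)
    for _ in range(gs_num_sample)
], dim=0)
samples = stacked.sum(dim=0)
samples = samples / samples.sum()     # renormalize so the weights sum to 1

# Slice the flat weight vector back into one chunk per edge, mirroring the
# counter-based loop that hands weights to each GsOp.
per_edge_weights = samples.split(primitives_per_edge)
print([w.tolist() for w in per_edge_weights])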
Example #7
    def build(self, model_desc:ModelDesc, search_iter:int)->None:
        # if this is not the first iteration, we add new node to each cell
        if search_iter > 0:
            self.add_node(model_desc)

        conf = get_conf()
        self._gs_num_sample = conf['nas']['search']['gs']['num_sample']

        for cell_desc in model_desc.cell_descs():
            self._build_cell(cell_desc, self._gs_num_sample)
Example #8
    def trainer_class(self) -> TArchTrainer:
        conf = get_conf()
        trainer = conf['nas']['search']['divnas']['archtrainer']

        if trainer == 'bilevel':
            return BilevelArchTrainer
        elif trainer == 'noalpha':
            return ArchTrainer
        else:
            raise NotImplementedError
Example #9
    def finalizers(self)->Finalizers:
        conf = common.get_conf()
        finalizer = conf['nas']['search']['finalizer']

        if not finalizer or finalizer == 'default':
            return Finalizers()
        elif finalizer == 'random':
            return RandomFinalizers()
        else:
            raise NotImplementedError
Example #10
    def finalize_node(self, node: nn.ModuleList, node_index: int,
                      node_desc: NodeDesc, max_final_edges: int, *args,
                      **kwargs) -> NodeDesc:
        conf = get_conf()
        gs_num_sample = conf['nas']['search']['model_desc']['cell']['gs'][
            'num_sample']

        # gather the alphas of all edges in this node
        node_alphas = []
        for edge in node:
            if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == GsOp:
                alphas = [alpha for op, alpha in edge._op.ops()]
                node_alphas.extend(alphas)

        # TODO: will creating a tensor from a list of tensors preserve the graph?
        node_alphas = torch.Tensor(node_alphas)

        assert node_alphas.nelement() > 0

        # sample ops via gumbel softmax
        sample_storage = []
        for _ in range(gs_num_sample):
            sampled = F.gumbel_softmax(node_alphas,
                                       tau=1,
                                       hard=True,
                                       eps=1e-10,
                                       dim=-1)
            sample_storage.append(sampled)

        samples_summed = torch.sum(torch.stack(sample_storage, dim=0), dim=0)

        # send the sampled op weights to their
        # respective edges to be used for edge level finalize
        selected_edges = []
        counter = 0
        for _, edge in enumerate(node):
            if hasattr(edge._op, 'PRIMITIVES') and type(edge._op) == GsOp:
                this_edge_sampled_weights = samples_summed[
                    counter:counter + len(edge._op.PRIMITIVES)]
                counter += len(edge._op.PRIMITIVES)
                # finalize the edge
                if this_edge_sampled_weights.bool().any():
                    op_desc, _ = edge._op.finalize(this_edge_sampled_weights)
                    new_edge = EdgeDesc(op_desc, edge.input_ids)
                    selected_edges.append(new_edge)

        # delete excess edges
        if len(selected_edges) > max_final_edges:
            # since these are sampled edges there is no ordering
            # amongst them, so we just arbitrarily select a few
            selected_edges = selected_edges[:max_final_edges]

        return NodeDesc(selected_edges, node_desc.conv_params)
Example #11
    def finalize_model(self,
                       model: Model,
                       to_cpu=True,
                       restore_device=True) -> ModelDesc:

        logger.pushd('finalize')

        # get config and train data loader
        # TODO: confirm this is correct in case you get silent bugs
        conf = get_conf()
        conf_loader = conf['nas']['search']['loader']
        train_dl, val_dl, test_dl = get_data(conf_loader)

        # wrap all cells in the model
        self._divnas_cells: Dict[int, Divnas_Cell] = {}
        for _, cell in enumerate(model.cells):
            divnas_cell = Divnas_Cell(cell)
            self._divnas_cells[id(cell)] = divnas_cell

        # go through all edges in the DAG and, if they are of DivOp
        # type, set them to collect activations
        sigma = conf['nas']['search']['divnas']['sigma']
        for _, dcell in enumerate(self._divnas_cells.values()):
            dcell.collect_activations(DivOp, sigma)

        # Now we run one evaluation epoch to collect activations. We do it on
        # CPU, otherwise we might run into memory issues; later we can redo
        # the whole logic in PyTorch itself. At the end of this, each node in
        # a cell will have the covariance matrix of all incoming edges' ops.
        model = model.cpu()
        model.eval()
        with torch.no_grad():
            for _ in range(1):
                for _, (x, _) in enumerate(train_dl):
                    model(x)  # forward pass only; outputs are not needed
                    # now you can go through and update the
                    # node covariances in every cell
                    for dcell in self._divnas_cells.values():
                        dcell.update_covs()

        logger.popd()

        return super().finalize_model(model, to_cpu, restore_device)
Example #12
    def update_alphas(self, eta: float, current_t: int, total_t: int,
                      grad_clip: float):
        grad_flat = torch.flatten(self._grad)
        rewards = torch.tensor([
            -torch.dot(grad_flat, torch.flatten(activ))
            for activ in self._activs
        ])
        exprewards = torch.exp(eta * rewards).cuda()
        # NOTE: Will this remain registered?
        self._alphas[0] = torch.mul(self._alphas[0], exprewards)

        # weak learner eviction
        conf = get_conf()
        to_evict = conf['nas']['search']['xnas']['to_evict']
        if to_evict:
            theta = max(self._alphas[0]) * ma.exp(-2 * eta * grad_clip *
                                                  (total_t - current_t))
            assert len(self._ops) == self._alphas[0].shape[0]
            to_keep_mask = self._alphas[0] >= theta
            num_ops_kept = torch.sum(to_keep_mask).item()
            assert num_ops_kept > 0
            # zero out the weights which are evicted
            self._alphas[0] = torch.mul(self._alphas[0], to_keep_mask)

        # save some debugging info
        expdir = get_expdir()
        filename = os.path.join(expdir, str(id(self)) + '.txt')

        # save debug info to file
        alphas = [
            str(self._alphas[0][i].item())
            for i in range(self._alphas[0].shape[0])
        ]
        with open(filename, 'a') as f:
            f.write(str(alphas))
            f.write('\n')
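update_alphas is a multiplicative-weights (exponentiated-gradient) step followed by an eviction test: the threshold theta is chosen so that an op whose weight has fallen below it cannot catch up to the current leader in the remaining (total_t - current_t) steps, since a single step changes the log-weight gap by at most about 2 * eta * grad_clip. A toy version with made-up numbers, sketching only the arithmetic:

import math
import torch

# Toy numbers; in the real trainer eta is the per-cell lr computed in pre_fit
# and rewards are -<loss gradient, op activation> for each candidate op.
eta, grad_clip = 1e-4, 5.0
current_t, total_t = 1000, 100000

alphas = torch.ones(5)
rewards = torch.randn(5)

# Exponentiated-gradient step: multiply each weight by exp(eta * reward).
alphas = alphas * torch.exp(eta * rewards)

# Eviction: zero out weights that can no longer catch up to the leader.
theta = alphas.max() * math.exp(-2 * eta * grad_clip * (total_t - current_t))
alphas = alphas * (alphas >= theta)
print(alphas)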