Example #1
    def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """ build the network """

        # find the search config
        if not os.path.isfile(self.search_config_path):
            self.search_config_path = Builder.find_net_config_path(
                self.search_config_path, pattern='search')

        # create a temporary search strategy
        tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
        sm = StrategyManager()
        assert len(sm.get_strategies_list()) == 0, \
            "cannot load when there already is a search network"
        sm.add_strategy(tmp_s)
        sm.set_fixed_strategy_name('__tmp__')

        # create a search network
        search_net = Register.builder.load_from_config(self.search_config_path)
        assert isinstance(search_net, SearchUninasNetwork)
        search_net.build(s_in, s_out)
        search_net.set_forward_strategy(False)

        # set the architecture, get the config
        req_gene = ""
        if self.gene == 'random':
            search_net.forward_strategy()
            gene = sm.get_all_finalized_indices(unique=True, flat=True)
            self.model_name = "random(%s)" % str(gene)
            req_gene = " (%s)" % self.gene
        else:
            gene = split(self.gene, int)
        l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene)
        assert l0 == l1, \
            "number of unique choices in the network (%d) must match length of the gene (%d)" % (l0, l1)
        search_net.forward_strategy(fixed_arc=gene)
        config = search_net.config(finalize=True)

        # clean up
        sm.delete_strategy('__tmp__')
        del sm
        del search_net

        # build the actually used finalized network
        LoggerManager().get_logger().info(
            "Extracting architecture %s%s from the super-network" %
            (gene, req_gene))
        self.net = Register.builder.from_config(config)
        return self.net.build(s_in, s_out)
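
Example #1 extracts a stand-alone architecture from a trained super-network: a throwaway RandomChoiceStrategy is registered with the StrategyManager, the search network is loaded and built, the gene pins every architecture choice, and config(finalize=True) yields the config of the resulting sub-network. Below is a minimal sketch of the same sequence as a free-standing helper; the function name extract_finalized_config is made up for illustration, and it assumes the UniNAS classes used above are importable.

def extract_finalized_config(search_config_path, gene, s_in, s_out):
    """ hypothetical helper: turn a trained super-network plus a gene into a finalized config """
    # a temporary strategy is required so the search network can be built and its choices fixed
    sm = StrategyManager()
    sm.add_strategy(RandomChoiceStrategy(max_epochs=1, name='__tmp__'))
    sm.set_fixed_strategy_name('__tmp__')

    search_net = Register.builder.load_from_config(search_config_path)
    assert isinstance(search_net, SearchUninasNetwork)
    search_net.build(s_in, s_out)
    search_net.set_forward_strategy(False)       # do not re-sample the architecture on forward passes
    search_net.forward_strategy(fixed_arc=gene)  # pin every choice to the gene
    config = search_net.config(finalize=True)    # config describes only the chosen sub-network

    sm.delete_strategy('__tmp__')                # remove the temporary strategy again
    return config
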
Example #2
    def _initialize_weights(self, net: AbstractModule, logger: logging.Logger):
        assert isinstance(
            net, AbstractUninasNetwork
        ), "This initializer will not work with external networks!"
        search_config = Builder.find_net_config_path(self.path,
                                                     pattern='search')

        checkpoint = CheckpointCallback.load_last_checkpoint(self.path)
        state_dict = checkpoint.get('state_dict')

        # figure out correct weights in super-network checkpoint
        if len(self.gene) > 0:
            log_headline(logger,
                         "tmp network to track used params",
                         target_len=80)
            sm = StrategyManager()
            tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
            assert len(sm.get_strategies_list()) == 0, \
                "cannot load when there already is a search network"
            sm.add_strategy(tmp_s)
            sm.set_fixed_strategy_name('__tmp__')

            search_net = Builder().load_from_config(search_config)
            assert isinstance(search_net, SearchUninasNetwork)
            s_in, s_out = net.get_shape_in(), net.get_shape_out()
            search_net.build(s_in, s_out[0])
            search_net.set_forward_strategy(False)
            search_net.forward_strategy(fixed_arc=self.gene)
            tracker = search_net.track_used_params(
                s_in.random_tensor(batch_size=2))
            # tracker.print()

            logger.info(' > loading weights of gene %s from checkpoint "%s"' %
                        (str(self.gene), self.path))
            target_dict = net.state_dict()
            target_names = list(target_dict.keys())
            new_dict = {}

            # add all stem and head weights; they are at the front of the dict and have practically the same names
            log_columns = [('shape in checkpoint', 'name in checkpoint',
                            'name in network', 'shape in network')]
            for k, v in state_dict.items():
                if '.stem.' in k or '.heads.' in k:
                    tn = target_names.pop(0)
                    ts = target_dict[tn].shape
                    log_columns.append(
                        (str(list(v.shape)), k, tn, str(list(ts))))
                    n = k.replace('net.', '', 1)
                    assert n == tn
                    new_dict[n] = v

            # add all cell weights; names generally cannot be compared, only shapes
            for i, tracker_cell_entry in enumerate(tracker.get_cells()):
                for entry in tracker_cell_entry.get_pareto_best():
                    tn = target_names.pop(0)
                    ts = target_dict[tn].shape
                    log_columns.append((str(list(entry.shape)), entry.name, tn,
                                        str(list(ts))))
                    assert entry.shape == ts,\
                        'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
                    new_dict[tn] = state_dict[entry.name]

            # log matches, load
            log_in_columns(logger, log_columns, add_bullets=True)
            net.load_state_dict(new_dict, strict=self.strict)

            # clean up
            del search_net
            sm.delete_strategy('__tmp__')
            del sm

        # simply load
        else:
            logger.info(' > simply loading state_dict')
            net.load_state_dict(state_dict, strict=self.strict)
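
Example #2 copies the weights of a single architecture out of a super-network checkpoint into the finalized stand-alone network: stem and head tensors are matched by name, cell tensors only by order and shape, using the Tracker returned by track_used_params. Below is a condensed sketch of just the cell-matching step; the helper name select_cell_weights is made up, and it assumes a built SearchUninasNetwork with the gene already fixed plus a target_names list from which the stem/head entries have already been removed.

def select_cell_weights(search_net, state_dict, target_dict, target_names, x):
    """ hypothetical helper: pick the checkpoint tensors actually used by the fixed architecture """
    tracker = search_net.track_used_params(x)  # hooked forward pass, records the weights used per cell
    new_dict = {}
    for cell_entry in tracker.get_cells():
        for entry in cell_entry.get_pareto_best():
            tn = target_names.pop(0)           # target weights appear in the same order as tracked ones
            assert entry.shape == target_dict[tn].shape, \
                'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
            new_dict[tn] = state_dict[entry.name]
    return new_dict
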
Example #3
class SearchUninasNetwork(AbstractUninasNetwork):

    def __init__(self, model_name: str, net: AbstractNetworkBody, do_forward_strategy=True, *args, **kwargs):
        super().__init__(model_name=model_name, net=net, *args, **kwargs)
        self.do_forward_strategy = do_forward_strategy  # unnecessary line to remove "error" highlighting
        self._add_to_kwargs(do_forward_strategy=self.do_forward_strategy)
        self.strategy_manager = StrategyManager()
        self.strategies = None

    @classmethod
    def from_args(cls, args: Namespace, index=None, weight_strategies: Union[dict, str] = None)\
            -> 'SearchUninasNetwork':
        """
        :param args: global argparse namespace
        :param index: argument index
        :param weight_strategies: {strategy name: [cell indices]}, or name used for all, or None for defaults
        """
        all_parsed = cls._all_parsed_arguments(args)
        cls_net = cls._parsed_meta_argument(Register.network_bodies, 'cls_network_body', args, index=index)
        net = cls_net.search_network_from_args(args, index=index, weight_strategies=weight_strategies)
        return cls(cls.__name__, net, **all_parsed)

    @classmethod
    def meta_args_to_add(cls) -> [MetaArgument]:
        """
        list meta arguments to add to argparse when this class is chosen;
        classes specified in meta arguments may have their own respective arguments
        """
        return super().meta_args_to_add() + [
            MetaArgument('cls_network_body', Register.network_bodies, help_name='network', allowed_num=1),
        ]

    def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """ build the network """
        s = self.net.build(s_in, s_out)
        self.strategies = self.strategy_manager.get_strategies_list()
        self.strategy_manager.build()
        return s

    def get_strategy_manager(self) -> StrategyManager:
        return self.strategy_manager

    def set_forward_strategy(self, forward_strategy: bool):
        self.do_forward_strategy = forward_strategy

    def get_forward_strategy(self) -> bool:
        return self.do_forward_strategy

    def forward(self, x: torch.Tensor, ws_kwargs: dict = None, **net_kwargs) -> [torch.Tensor]:
        """
        first forward the weight strategy, then the network
        """
        if self.do_forward_strategy:
            self.forward_strategy(**({} if ws_kwargs is None else ws_kwargs))
        return super().forward(x, **net_kwargs)

    def forward_net(self, x: torch.Tensor, **net_kwargs) -> [torch.Tensor]:
        """
        forward only the network
        """
        return self.net(x, **net_kwargs)

    def forward_strategy(self, **ws_kwargs):
        """
        forward only the weight strategy
        """
        self.strategy_manager.forward(**ws_kwargs)

    def str(self, depth=0, **_) -> str:
        r = '{d}{name}(\n{ws},{net}\n{d}])'.format(**{
            'd': '{d}',
            'd1': '{d1}',
            'name': self.__class__.__name__,
            'ws': '{d1}Strategies: [%s]' % ', '.join([ws.str() for ws in self.strategies]),
            'net': self.net.str(depth=depth+1, max_depth=self.log_detail, **_),
        })
        r = r.replace('{d}', '. '*depth).replace('{d1}', '. '*(depth+1))
        return r

    def config(self, finalize=True, **_) -> dict:
        if finalize:
            return self.net.config(finalize=finalize, **_)
        return super().config(finalize=finalize, **_)

    def named_net_arc_parameters(self) -> (list, list):
        # all named parameters
        net_params, arc_params, duplicate_idx = list(self.net.named_parameters()), [], []
        for ws in self.strategies:
            arc_params += list(ws.named_parameters())
        # remove arc parameters from the network
        for an, ap in arc_params:
            for idx, (n, p) in enumerate(net_params):
                if ap is p:
                    duplicate_idx.append(idx)
        for idx in sorted(duplicate_idx, reverse=True):
            net_params.pop(idx)
        return net_params, arc_params

    def track_used_params(self, x: torch.Tensor) -> Tracker:
        """
        track which weights are used for the current architecture,
        and in which cell
        """
        tracker = Tracker()
        is_train = self.training
        self.eval()
        handles = []
        ws_modules = []
        x = x.to(self.get_device())

        # find all modules that have a weight strategy, add hooks
        for name, module in self.named_modules():
            if hasattr(module, 'ws') and isinstance(module.ws, (AbstractWeightStrategy, StrategyManager)):
                ws_modules.append(module)
                for name2, m2 in module.named_modules():
                    if len(get_to_print(m2)) >= 1:
                        handles.append(m2.register_forward_hook(Hook(tracker, 'net.%s.%s' % (name, name2))))

        # forward pass with the current arc, all used weights are tracked
        self.forward_net(x)

        tracker.finalize()
        for h in handles:
            h.remove()
        self.train(is_train)
        return tracker

    @classmethod
    def get_space_tuple(cls, unique=True, flat=False) -> tuple:
        """ tuple of final topology """
        return tuple(StrategyManager().get_all_finalized_indices(unique=unique, flat=flat))
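
Example #3 is the search-network wrapper itself: by default forward() first lets the StrategyManager sample the current architecture (forward_strategy) and then runs the network body, while set_forward_strategy(False) decouples the two so a single fixed architecture can be evaluated. A minimal usage sketch, assuming a built SearchUninasNetwork instance net and an input batch x; the gene below is made up, one index per architecture choice.

# default behaviour: the strategies pick a new architecture on every forward pass
y = net(x)

# evaluate one fixed architecture instead
net.set_forward_strategy(False)
net.forward_strategy(fixed_arc=[0, 2, 1, 3])  # made-up gene
y = net.forward_net(x)

# record which weights that architecture actually uses, per cell
tracker = net.track_used_params(x)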