def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList: """ build the network """ # find the search config if not os.path.isfile(self.search_config_path): self.search_config_path = Builder.find_net_config_path( self.search_config_path, pattern='search') # create a temporary search strategy tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__') sm = StrategyManager() assert len(sm.get_strategies_list( )) == 0, "can not load when there already is a search network" sm.add_strategy(tmp_s) sm.set_fixed_strategy_name('__tmp__') # create a search network search_net = Register.builder.load_from_config(self.search_config_path) assert isinstance(search_net, SearchUninasNetwork) search_net.build(s_in, s_out) search_net.set_forward_strategy(False) # set the architecture, get the config req_gene = "" if self.gene == 'random': search_net.forward_strategy() gene = sm.get_all_finalized_indices(unique=True, flat=True) self.model_name = "random(%s)" % str(gene) req_gene = " (%s)" % self.gene else: gene = split(self.gene, int) l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene) assert l0 == l1, "number of unique choices in the network (%d) must match length of the gene (%d)" % ( l0, l1) search_net.forward_strategy(fixed_arc=gene) config = search_net.config(finalize=True) # clean up sm.delete_strategy('__tmp__') del sm del search_net # build the actually used finalized network LoggerManager().get_logger().info( "Extracting architecture %s%s from the super-network" % (gene, req_gene)) self.net = Register.builder.from_config(config) return self.net.build(s_in, s_out)
def _initialize_weights(self, net: AbstractModule, logger: logging.Logger): assert isinstance( net, AbstractUninasNetwork ), "This initializer will not work with external networks!" search_config = Builder.find_net_config_path(self.path, pattern='search') checkpoint = CheckpointCallback.load_last_checkpoint(self.path) state_dict = checkpoint.get('state_dict') # figure out correct weights in super-network checkpoint if len(self.gene) > 0: log_headline(logger, "tmp network to track used params", target_len=80) sm = StrategyManager() tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__') assert len(sm.get_strategies_list( )) == 0, "can not load when there already is a search network" sm.add_strategy(tmp_s) sm.set_fixed_strategy_name('__tmp__') search_net = Builder().load_from_config(search_config) assert isinstance(search_net, SearchUninasNetwork) s_in, s_out = net.get_shape_in(), net.get_shape_out() search_net.build(s_in, s_out[0]) search_net.set_forward_strategy(False) search_net.forward_strategy(fixed_arc=self.gene) tracker = search_net.track_used_params( s_in.random_tensor(batch_size=2)) # tracker.print() logger.info(' > loading weights of gene %s from checkpoint "%s"' % (str(self.gene), self.path)) target_dict = net.state_dict() target_names = list(target_dict.keys()) new_dict = {} # add all stem and head weights, they are at the front of the dict and have pretty much the same name log_columns = [('shape in checkpoint', 'name in checkpoint', 'name in network', 'shape in network')] for k, v in state_dict.items(): if '.stem.' in k or '.heads.' in k: tn = target_names.pop(0) ts = target_dict[tn].shape log_columns.append( (str(list(v.shape)), k, tn, str(list(ts)))) n = k.replace('net.', '', 1) assert n == tn new_dict[n] = v # add all cell weights, can generally not compare names, only shapes for i, tracker_cell_entry in enumerate(tracker.get_cells()): for entry in tracker_cell_entry.get_pareto_best(): tn = target_names.pop(0) ts = target_dict[tn].shape log_columns.append((str(list(entry.shape)), entry.name, tn, str(list(ts)))) assert entry.shape == ts,\ 'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn) new_dict[tn] = state_dict[entry.name] # log matches, load log_in_columns(logger, log_columns, add_bullets=True) net.load_state_dict(new_dict, strict=self.strict) # clean up del search_net sm.delete_strategy('__tmp__') del sm # simply load else: logger.info(' > simply loading state_dict') net.load_state_dict(state_dict, strict=self.strict)
class SearchUninasNetwork(AbstractUninasNetwork): def __init__(self, model_name: str, net: AbstractNetworkBody, do_forward_strategy=True, *args, **kwargs): super().__init__(model_name=model_name, net=net, *args, **kwargs) self.do_forward_strategy = do_forward_strategy # unnecessary line to remove "error" highlighting self._add_to_kwargs(do_forward_strategy=self.do_forward_strategy) self.strategy_manager = StrategyManager() self.strategies = None @classmethod def from_args(cls, args: Namespace, index=None, weight_strategies: Union[dict, str] = None)\ -> 'SearchUninasNetwork': """ :param args: global argparse namespace :param index: argument index :param weight_strategies: {strategy name: [cell indices]}, or name used for all, or None for defaults """ all_parsed = cls._all_parsed_arguments(args) cls_net = cls._parsed_meta_argument(Register.network_bodies, 'cls_network_body', args, index=index) net = cls_net.search_network_from_args(args, index=index, weight_strategies=weight_strategies) return cls(cls.__name__, net, **all_parsed) @classmethod def meta_args_to_add(cls) -> [MetaArgument]: """ list meta arguments to add to argparse for when this class is chosen, classes specified in meta arguments may have their own respective arguments """ return super().meta_args_to_add() + [ MetaArgument('cls_network_body', Register.network_bodies, help_name='network', allowed_num=1), ] def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList: """ build the network """ s = self.net.build(s_in, s_out) self.strategies = self.strategy_manager.get_strategies_list() self.strategy_manager.build() return s def get_strategy_manager(self) -> StrategyManager: return self.strategy_manager def set_forward_strategy(self, forward_strategy: bool): self.do_forward_strategy = forward_strategy def get_forward_strategy(self) -> bool: return self.do_forward_strategy def forward(self, x: torch.Tensor, ws_kwargs: dict = None, **net_kwargs) -> [torch.Tensor]: """ forward first the weight strategy, then the network """ if self.do_forward_strategy: self.forward_strategy(**({} if ws_kwargs is None else ws_kwargs)) return super().forward(x, **net_kwargs) def forward_net(self, x: torch.Tensor, **net_kwargs) -> [torch.Tensor]: """ forward only the network """ return self.net(x, **net_kwargs) def forward_strategy(self, **ws_kwargs): """ forward only the weight strategy """ self.strategy_manager.forward(**ws_kwargs) def str(self, depth=0, **_) -> str: r = '{d}{name}(\n{ws},{net}\n{d}])'.format(**{ 'd': '{d}', 'd1': '{d1}', 'name': self.__class__.__name__, 'ws': '{d1}Strategies: [%s]' % ', '.join([ws.str() for ws in self.strategies]), 'net': self.net.str(depth=depth+1, max_depth=self.log_detail, **_), }) r = r.replace('{d}', '. '*depth).replace('{d1}', '. '*(depth+1)) return r def config(self, finalize=True, **_) -> dict: if finalize: return self.net.config(finalize=finalize, **_) return super().config(finalize=finalize, **_) def named_net_arc_parameters(self) -> (list, list): # all named parameters net_params, arc_params, duplicate_idx = list(self.net.named_parameters()), [], [] for ws in self.strategies: arc_params += list(ws.named_parameters()) # remove arc parameters from the network for an, ap in arc_params: for idx, (n, p) in enumerate(net_params): if ap is p: duplicate_idx.append(idx) for idx in sorted(duplicate_idx, reverse=True): net_params.pop(idx) return net_params, arc_params def track_used_params(self, x: torch.Tensor) -> Tracker: """ track which weights are used for the current architecture, and in which cell """ tracker = Tracker() is_train = self.training self.eval() handles = [] ws_modules = [] x = x.to(self.get_device()) # find all modules that have a weight strategy, add hooks for name, module in self.named_modules(): if hasattr(module, 'ws') and isinstance(module.ws, (AbstractWeightStrategy, StrategyManager)): ws_modules.append(module) for name2, m2 in module.named_modules(): if len(get_to_print(m2)) >= 1: handles.append(m2.register_forward_hook(Hook(tracker, 'net.%s.%s' % (name, name2)))) # forward pass with the current arc, all used weights are tracked self.forward_net(x) tracker.finalize() for h in handles: h.remove() self.train(is_train) return tracker @classmethod def get_space_tuple(cls, unique=True, flat=False) -> tuple: """ tuple of final topology """ return tuple(StrategyManager().get_all_finalized_indices(unique=unique, flat=flat))