def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    Build the network by extracting a single architecture from a super-network.

    A temporary search network is loaded from the search config and built,
    the requested gene is applied to it (or, if self.gene == 'random', one
    forward pass of the temporary strategy picks the architecture), and the
    finalized config of that network is used to create the network that is
    actually used (stored in self.net).

    The temporary '__tmp__' strategy is always removed again, even when
    loading or applying the gene fails, so a failed call does not leave the
    StrategyManager in a state where every later call trips the
    "there already is a search network" assertion.

    :param s_in: input shape the networks are built with
    :param s_out: output shape the networks are built with
    :return: the result of self.net.build(s_in, s_out)
    :raises AssertionError: if the StrategyManager already has strategies,
        if the loaded network is not a SearchUninasNetwork, or if the gene
        length does not match the number of unique choices in the network
    """
    # find the search config
    if not os.path.isfile(self.search_config_path):
        self.search_config_path = Builder.find_net_config_path(self.search_config_path, pattern='search')

    # create a temporary search strategy
    tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
    sm = StrategyManager()
    assert len(sm.get_strategies_list()) == 0, "can not load when there already is a search network"
    sm.add_strategy(tmp_s)

    # bound up front so that the cleanup below can always delete the name
    search_net = None
    try:
        sm.set_fixed_strategy_name('__tmp__')

        # create a search network
        search_net = Register.builder.load_from_config(self.search_config_path)
        assert isinstance(search_net, SearchUninasNetwork)
        search_net.build(s_in, s_out)
        search_net.set_forward_strategy(False)

        # set the architecture, get the config
        req_gene = ""
        if self.gene == 'random':
            search_net.forward_strategy()
            gene = sm.get_all_finalized_indices(unique=True, flat=True)
            self.model_name = "random(%s)" % str(gene)
            req_gene = " (%s)" % self.gene
        else:
            gene = split(self.gene, int)
            l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene)
            assert l0 == l1, \
                "number of unique choices in the network (%d) must match length of the gene (%d)" % (l0, l1)
            search_net.forward_strategy(fixed_arc=gene)
        config = search_net.config(finalize=True)
    finally:
        # clean up, on success and on failure alike
        sm.delete_strategy('__tmp__')
        del sm
        del search_net

    # build the actually used finalized network
    LoggerManager().get_logger().info("Extracting architecture %s%s from the super-network" % (gene, req_gene))
    self.net = Register.builder.from_config(config)
    return self.net.build(s_in, s_out)
def _initialize_weights(self, net: AbstractModule, logger: logging.Logger):
    """
    Initialize the weights of a network from the last checkpoint in self.path.

    If self.gene is empty, the checkpoint state_dict is loaded directly.
    Otherwise a temporary search network is built from the search config,
    the gene is applied to it, and the parameters it uses are tracked with a
    random input, so that the matching super-network weights can be copied:
    stem and head weights are matched by name, cell weights by order and shape.

    The temporary '__tmp__' strategy is always removed again, even when a
    name/shape check or the loading itself fails, so a failed call does not
    leave it registered in the StrategyManager.

    :param net: network to load the weights into, must be an AbstractUninasNetwork
    :param logger: logger used for progress and the weight-matching table
    :raises AssertionError: if net is not an AbstractUninasNetwork, if the
        StrategyManager already has strategies, if the loaded network is not a
        SearchUninasNetwork, or if names/shapes of matched weights differ
    """
    assert isinstance(net, AbstractUninasNetwork), "This initializer will not work with external networks!"
    search_config = Builder.find_net_config_path(self.path, pattern='search')
    checkpoint = CheckpointCallback.load_last_checkpoint(self.path)
    state_dict = checkpoint.get('state_dict')

    # simply load
    if not len(self.gene) > 0:
        logger.info(' > simply loading state_dict')
        net.load_state_dict(state_dict, strict=self.strict)
        return

    # figure out correct weights in super-network checkpoint
    log_headline(logger, "tmp network to track used params", target_len=80)
    sm = StrategyManager()
    tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
    assert len(sm.get_strategies_list()) == 0, "can not load when there already is a search network"
    sm.add_strategy(tmp_s)

    # bound up front so that the cleanup below can always delete the name
    search_net = None
    try:
        sm.set_fixed_strategy_name('__tmp__')

        search_net = Builder().load_from_config(search_config)
        assert isinstance(search_net, SearchUninasNetwork)
        s_in, s_out = net.get_shape_in(), net.get_shape_out()
        # only the first output shape is used to build the search network
        search_net.build(s_in, s_out[0])
        search_net.set_forward_strategy(False)
        search_net.forward_strategy(fixed_arc=self.gene)
        tracker = search_net.track_used_params(s_in.random_tensor(batch_size=2))

        logger.info(' > loading weights of gene %s from checkpoint "%s"' % (str(self.gene), self.path))
        target_dict = net.state_dict()
        # names are consumed front to back, in the order of the target state_dict
        target_names = list(target_dict.keys())
        new_dict = {}

        # add all stem and head weights, they are at the front of the dict and have pretty much the same name
        log_columns = [('shape in checkpoint', 'name in checkpoint', 'name in network', 'shape in network')]
        for k, v in state_dict.items():
            if '.stem.' in k or '.heads.' in k:
                tn = target_names.pop(0)
                ts = target_dict[tn].shape
                log_columns.append((str(list(v.shape)), k, tn, str(list(ts))))
                # checkpoint names only differ by the first 'net.' part
                n = k.replace('net.', '', 1)
                assert n == tn, 'Mismatching names "%s" and "%s"' % (n, tn)
                new_dict[n] = v

        # add all cell weights, can generally not compare names, only shapes
        for tracker_cell_entry in tracker.get_cells():
            for entry in tracker_cell_entry.get_pareto_best():
                tn = target_names.pop(0)
                ts = target_dict[tn].shape
                log_columns.append((str(list(entry.shape)), entry.name, tn, str(list(ts))))
                assert entry.shape == ts, \
                    'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
                new_dict[tn] = state_dict[entry.name]

        # log matches, load
        log_in_columns(logger, log_columns, add_bullets=True)
        net.load_state_dict(new_dict, strict=self.strict)
    finally:
        # clean up, on success and on failure alike
        del search_net
        sm.delete_strategy('__tmp__')
        del sm