@classmethod
def extend_args(cls, args_list: [str]):
    """
    allow modifying the arguments list before other classes' arguments are dynamically added,
    this should be used sparsely, as it is hard to keep track of
    """
    # find the last cls_network_body
    super().extend_args(args_list)

    # first find the correct config path, which is in all_args, enable short names (not only full paths)
    config_path = find_in_args_list(args_list, ['{cls_network}.config_path', '%s.config_path' % cls.__name__])
    config_path = Builder.find_net_config_path(config_path)

    # extract used classes from the network config file, add them to the current task config if missing
    used_classes = Builder().find_classes_in_config(config_path)
    network_name = used_classes['cls_network_body'][0]
    cls_network_body = Register.network_bodies.get(network_name)
    optional_meta = [m.argument.name for m in cls_network_body.meta_args_to_add() if m.optional_for_loading]
    print('\tbuilding a new net (config_only=False), added missing args from the network config file')
    for cls_n in ['cls_network_body'] + optional_meta:
        cls_c = find_in_args_list(args_list, [cls_n])
        if cls_c is None or len(cls_c) == 0:
            cls_v = ', '.join(used_classes[cls_n])
            print('\t  %s -> %s' % (cls_n, cls_v))
            args_list.append('--%s=%s' % (cls_n, cls_v))

@classmethod
def from_args(cls, args: Namespace, index=None) -> 'RetrainUninasNetwork':
    """
    :param args: global argparse namespace
    :param index: index for the args
    """
    all_parsed = cls._all_parsed_arguments(args, index=index)
    config_path = all_parsed.pop('config_path')
    config_path = Builder.find_net_config_path(config_path)
    net = Register.builder.load_from_config(config_path)
    return cls(model_name=Builder.net_config_name(config_path), net=net, **all_parsed)

def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """ build the network """
    # find the search config
    if not os.path.isfile(self.search_config_path):
        self.search_config_path = Builder.find_net_config_path(self.search_config_path, pattern='search')

    # create a temporary search strategy
    tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
    sm = StrategyManager()
    assert len(sm.get_strategies_list()) == 0, "can not load when there already is a search network"
    sm.add_strategy(tmp_s)
    sm.set_fixed_strategy_name('__tmp__')

    # create a search network
    search_net = Register.builder.load_from_config(self.search_config_path)
    assert isinstance(search_net, SearchUninasNetwork)
    search_net.build(s_in, s_out)
    search_net.set_forward_strategy(False)

    # set the architecture, get the config
    req_gene = ""
    if self.gene == 'random':
        search_net.forward_strategy()
        gene = sm.get_all_finalized_indices(unique=True, flat=True)
        self.model_name = "random(%s)" % str(gene)
        req_gene = " (%s)" % self.gene
    else:
        gene = split(self.gene, int)
        l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene)
        assert l0 == l1, "number of unique choices in the network (%d) must match the length of the gene (%d)" % (l0, l1)
        search_net.forward_strategy(fixed_arc=gene)
    config = search_net.config(finalize=True)

    # clean up
    sm.delete_strategy('__tmp__')
    del sm
    del search_net

    # build the actually used finalized network
    LoggerManager().get_logger().info("Extracting architecture %s%s from the super-network" % (gene, req_gene))
    self.net = Register.builder.from_config(config)
    return self.net.build(s_in, s_out)

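# --- hedged usage sketch, not part of the method above ----------------------------------
# The gene handled by _build2 is either the literal string 'random' or a comma-separated
# string of operation indices. `_split_gene` below is a hypothetical stand-in for the
# project's `split(gene, int)` helper, only meant to illustrate the expected format:
def _split_gene(gene_str: str) -> list:
    """ turn a gene string such as "1, 3, 0, 2" into a list of int indices """
    return [int(s) for s in gene_str.split(',') if len(s.strip()) > 0]

assert _split_gene("1, 3, 0, 2") == [1, 3, 0, 2]
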
def get_network(config_path: str, input_shape: Shape, output_shape: Shape, weights_path: str = None) -> AbstractUninasNetwork:
    """ create a network (model) from a config file, optionally load weights """
    builder = Builder()

    # get a new network
    network = builder.load_from_config(Builder.find_net_config_path(config_path))
    network = AbstractUninasNetwork(model_name="standalone", net=network, checkpoint_path="", assert_output_match=True)
    network.build(s_in=input_shape, s_out=output_shape)

    # load network weights; they are saved from a method, so the keys have to be mapped accordingly
    if isinstance(weights_path, str):
        CheckpointCallback.load_network(weights_path, network, num_replacements=1)
    return network

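# --- hedged usage sketch for get_network, relying on two assumptions: Shape wraps a list of
# dimensions (without the batch dimension), and 'MobileNetV2' resolves to a shipped net config
# via Builder.find_net_config_path --------------------------------------------------------
example_net = get_network('MobileNetV2',
                          input_shape=Shape([3, 224, 224]),    # assumed Shape signature
                          output_shape=Shape([1000]),
                          weights_path=None)                   # set a checkpoint path to also load weights
print('parameters: %d' % example_net.get_num_parameters())
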
@classmethod
def from_args(cls, args: Namespace, index=None) -> 'RetrainInsertConfigUninasNetwork':
    """
    :param args: global argparse namespace
    :param index: argument index
    """
    all_parsed = cls._all_parsed_arguments(args, index=index)
    config_path = Builder.find_net_config_path(all_parsed.pop('config_path'))
    net = cls._parsed_meta_argument(Register.network_bodies, 'cls_network_body', args, index=index)
    net = net.search_network_from_args(args, index=index)
    net.add_cells_from_config(Register.builder.load_config(config_path))
    return cls(model_name=Builder.net_config_name(config_path), net=net, **all_parsed)

def _initialize_weights(self, net: AbstractModule, logger: logging.Logger):
    assert isinstance(net, AbstractUninasNetwork), "This initializer will not work with external networks!"
    search_config = Builder.find_net_config_path(self.path, pattern='search')

    checkpoint = CheckpointCallback.load_last_checkpoint(self.path)
    state_dict = checkpoint.get('state_dict')

    # figure out the correct weights in the super-network checkpoint
    if len(self.gene) > 0:
        log_headline(logger, "tmp network to track used params", target_len=80)
        sm = StrategyManager()
        tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
        assert len(sm.get_strategies_list()) == 0, "can not load when there already is a search network"
        sm.add_strategy(tmp_s)
        sm.set_fixed_strategy_name('__tmp__')

        search_net = Builder().load_from_config(search_config)
        assert isinstance(search_net, SearchUninasNetwork)
        s_in, s_out = net.get_shape_in(), net.get_shape_out()
        search_net.build(s_in, s_out[0])
        search_net.set_forward_strategy(False)
        search_net.forward_strategy(fixed_arc=self.gene)
        tracker = search_net.track_used_params(s_in.random_tensor(batch_size=2))
        # tracker.print()

        logger.info(' > loading weights of gene %s from checkpoint "%s"' % (str(self.gene), self.path))
        target_dict = net.state_dict()
        target_names = list(target_dict.keys())
        new_dict = {}

        # add all stem and head weights, they are at the front of the dict and have pretty much the same name
        log_columns = [('shape in checkpoint', 'name in checkpoint', 'name in network', 'shape in network')]
        for k, v in state_dict.items():
            if '.stem.' in k or '.heads.' in k:
                tn = target_names.pop(0)
                ts = target_dict[tn].shape
                log_columns.append((str(list(v.shape)), k, tn, str(list(ts))))
                n = k.replace('net.', '', 1)
                assert n == tn
                new_dict[n] = v

        # add all cell weights, can generally not compare names, only shapes
        for i, tracker_cell_entry in enumerate(tracker.get_cells()):
            for entry in tracker_cell_entry.get_pareto_best():
                tn = target_names.pop(0)
                ts = target_dict[tn].shape
                log_columns.append((str(list(entry.shape)), entry.name, tn, str(list(ts))))
                assert entry.shape == ts,\
                    'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
                new_dict[tn] = state_dict[entry.name]

        # log matches, load
        log_in_columns(logger, log_columns, add_bullets=True)
        net.load_state_dict(new_dict, strict=self.strict)

        # clean up
        del search_net
        sm.delete_strategy('__tmp__')
        del sm

    # simply load
    else:
        logger.info(' > simply loading state_dict')
        net.load_state_dict(state_dict, strict=self.strict)

@classmethod
def make_from_single_dir(cls, path: str, space_name: str, arch_index: int) -> MiniResult:
    """ create a mini result by parsing a single training directory """
    # find gene and dataset in the task config
    task_configs = find_all_files(path, extension=name_task_config)
    assert len(task_configs) == 1
    with open(task_configs[0]) as config_file:
        config = json.load(config_file)
    gene = config.get('{cls_network}.gene')
    gene = split(gene, int)
    data_set = get_dataset_from_json(task_configs[0])
    data_set_name = data_set.__class__.__name__

    # find loss and accuracy in the tensorboard files
    average_last = 5
    metric_accuracy_train, metric_loss_train = "train/accuracy/1", "train/loss"
    metric_accuracy_test, metric_loss_test = "test/accuracy/1", "test/loss"
    tb_files = find_tb_files(path)
    assert len(tb_files) > 0
    events = read_event_files(tb_files)
    loss_train = events.get(metric_loss_train, None)
    loss_test = events.get(metric_loss_test, None)
    assert (loss_train is not None) and (loss_test is not None)
    accuracy_train = events.get(metric_accuracy_train, None)
    accuracy_test = events.get(metric_accuracy_test, None)
    assert (accuracy_train is not None) and (accuracy_test is not None)

    # figure out params and flops by building the network
    net_config_path = Builder.find_net_config_path(path)
    network = get_network(net_config_path, data_set.get_data_shape(), data_set.get_label_shape())

    # figure out latency at some point

    # return result
    return MiniResult(
        arch_index=arch_index,
        arch_str="%s(%s)" % (space_name, ", ".join([str(g) for g in gene])),
        arch_tuple=tuple(gene),
        params={data_set_name: network.get_num_parameters()},
        flops={data_set_name: network.profile_macs()},
        latency={data_set_name: -1},
        loss={data_set_name: {
            'train': np.mean([v.value for v in loss_train[-average_last:]]),
            'test': np.mean([v.value for v in loss_test[-average_last:]]),
        }},
        acc1={data_set_name: {
            'train': np.mean([v.value for v in accuracy_train[-average_last:]]),
            'test': np.mean([v.value for v in accuracy_test[-average_last:]]),
        }},
    )

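# --- hedged usage sketch, assuming `MiniBench` stands in for whatever class defines
# make_from_single_dir above; the directory layout and the space name 'fairnas' are
# placeholders, not taken from the source -------------------------------------------------
import glob

single_results = []
for i, run_dir in enumerate(sorted(glob.glob('/tmp/uninas_runs/*/'))):   # placeholder layout
    single_results.append(MiniBench.make_from_single_dir(run_dir, space_name='fairnas', arch_index=i))
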
def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 4,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        '{cls_network}.config_path': cfg_path,
    })
    net = exp.get_method().get_network()
    vt = VizTree(net)
    vt.print()
    vt.plot(save_path + 'net', add_subgraphs=True)
    print('Saved cell viz to %s' % save_path)


def visualize_file(config_path: str, save_dir: str):
    config_name = config_path.split('/')[-1].split('.')[0]
    save_path = '%s%s/' % (save_dir, config_name)
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)


if __name__ == '__main__':
    visualize_file(Builder.find_net_config_path('MobileNetV2'), '{path_tmp}/viz/')

def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 2,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        '{cls_task}.note': 'viz',
        '{cls_network}.config_path': cfg_path,
    })
    net = exp.get_method().get_network()
    for s in ['n', 'r']:
        for cell in net.get_cells():
            if cell.name.startswith(s):
                visualize_cell(cell, save_path, s)
                break
    print('Saved cell viz to %s' % save_path)


def visualize_file(config_path: str, save_dir: str):
    config_name = Builder.net_config_name(config_path)
    save_path = save_dir + config_name + '/'
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)


if __name__ == '__main__':
    visualize_file(Builder.find_net_config_path('DARTS_V1'), '{path_tmp}/viz/')