def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    build the network, count params, log, maybe load pretrained weights

    :param s_in: input shape, a copy is stored as 'shape_in'
    :param s_out: expected output shape of every head
    :return: the ShapeList returned by '_build2', one shape per head
    :raises ValueError: if an output shape mismatches and 'assert_output_match' is set
    """
    assert isinstance(s_out, Shape), "Attempting to build a network with an output that is not a Shape!"
    expected_out = s_out.copy(copy_id=True)
    self.shape_in = s_in.copy(copy_id=True)
    built_shapes = self._build2(s_in, s_out)
    LoggerManager().get_logger().info('Network built, it has %d parameters!' % self.get_num_parameters())

    # every head has to produce the expected output shape
    assert isinstance(built_shapes, ShapeList), "The network must output a list of Shapes, one shape per head! (ShapeList)"
    for head_shape in built_shapes.shapes:
        if expected_out == head_shape:
            continue
        msg = "One or more output shapes mismatch: %s, expected: %s" % (built_shapes, expected_out)
        if self.assert_output_match:
            raise ValueError(msg)
        # one warning is enough, stop checking after the first mismatch
        LoggerManager().get_logger().warning(msg)
        break

    # pretrained weights are only loaded if a checkpoint path is given
    has_checkpoint = len(self.checkpoint_path) > 0
    if has_checkpoint:
        weights_path = CheckpointCallback.find_pretrained_weights_path(
            self.checkpoint_path, self.model_name, raise_missing=has_checkpoint)
        num_replacements = 1 if self.is_external() else 999
        network = self.get_network()
        success = CheckpointCallback.load_network(weights_path, network, num_replacements)
        self.loaded_weights(success)

    # remember the shapes, the first head defines the output shape
    self.shape_out = built_shapes.shapes[0].copy(copy_id=True)
    self.shape_in_list = self.shape_in.shape
    self.shape_out_list = self.shape_out.shape
    return built_shapes
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    build the stem, every cell in 'cell_order' and the heads placed after cells,
    log a table with the shapes and parameter counts of all built parts

    :param s_in: input shape, given to the stem
    :param s_out: requested output shape, given to every head that is built
    :return: ShapeList with one output shape per built head
    """
    LoggerManager().get_logger().info('Building %s:' % self.__class__.__name__)
    rows = [('cell index', 'name', 'class', 'input shapes', '', 'output shapes', '#params')]

    def get_row(idx, name: str, obj: AbstractModule) -> tuple:
        """ one table row for a built module: index, name, class, in/inner/out shapes, #params """
        s_in_str = obj.get_shape_in().str()
        # the 'shape_inner' entry is optional, the column stays empty if it is not cached
        s_inner = obj.get_cached('shape_inner')
        s_inner_str = '' if s_inner is None else s_inner.str()
        s_out_str = obj.get_shape_out().str()
        return str(idx), name, obj.__class__.__name__, s_in_str, s_inner_str, s_out_str, count_parameters(obj)

    # 's_out' is rebound to the cell outputs below, keep a copy of the requested output shape
    s_out_data = s_out.copy()
    out_shapes = self.stem.build(s_in)
    final_out_shapes = []
    rows.append(get_row('', '-', self.stem))

    # cells and (aux) heads
    updated_cell_order = []
    for i, cell_name in enumerate(self.cell_order):
        strategy_name, cell = self._get_cell(name=cell_name, cell_index=i)
        assert self.stem.num_outputs() == cell.num_inputs() == cell.num_outputs(), 'Cell does not fit the network!'
        updated_cell_order.append(cell.name)
        # the cell receives the most recent 'cell.num_inputs()' shapes
        s_ins = out_shapes[-cell.num_inputs():]
        # NOTE(review): the extent of this 'with' block is inferred, confirm that only the cell build belongs to it
        with StrategyManagerDefault(strategy_name):
            s_out = cell.build(s_ins.copy(), features_mul=self.features_mul, features_fixed=self.features_first_cell if i == 0 else -1)
        out_shapes.extend(s_out)
        rows.append(get_row(i, cell_name, cell))
        self.cells.append(cell)

        # optional (aux) head after every cell
        head = self._head_positions.get(i, None)
        if head is not None:
            if head.weight > 0:
                # the head is built on the last output shape of this cell
                final_out_shapes.append(head.build(s_out[-1], s_out_data))
                rows.append(get_row('', '-', head))
            else:
                LoggerManager().get_logger().info('not adding head after cell %d, weight <= 0' % i)
                del self._head_positions[i]
        else:
            assert i != len(self.cell_order) - 1, "Must have a head after the final cell"

    # remove heads that are impossible to add
    # NOTE(review): this only removes the head from 'self.heads', its entry in 'self._head_positions' is kept
    for i in self._head_positions.keys():
        if i >= len(self.cells):
            LoggerManager().get_logger().warning('Can not add a head after cell %d which does not exist, deleting the head!' % i)
            head = self._head_positions.get(i)
            for j, head2 in enumerate(self.heads):
                if head is head2:
                    self.heads.__delitem__(j)
                    break

    # the network output is the list of all built head outputs
    s_out = ShapeList(final_out_shapes)
    rows.append(('complete network', '', '', self.get_shape_in().str(), '', s_out.str(), count_parameters(self)))
    log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
    # store the names of the cells that were actually used
    self.set(cell_order=updated_cell_order)
    return s_out
def verify():
    """
    measure the top-1 / top-5 accuracy of a network on the ImageNet test loader and log the results

    the network config, weights and data settings are read from the command line, unknown arguments are ignored
    """
    logger = LoggerManager().get_logger()

    # command line arguments
    parser = argparse.ArgumentParser('get_network')
    argument_specs = [
        ('--config_path', dict(type=str, default='FairNasC')),
        ('--weights_path', dict(type=str, default='{path_tmp}/s3/')),
        ('--data_dir', dict(type=str, default='{path_data}/ImageNet_ILSVRC2012/')),
        ('--data_batch_size', dict(type=int, default=128)),
        ('--data_num_workers', dict(type=int, default=8)),
        ('--num_batches', dict(type=int, default=-1, help='>0 to stop early')),
    ]
    for flag, kwargs in argument_specs:
        parser.add_argument(flag, **kwargs)
    args, _ = parser.parse_known_args()

    # ImageNet with default augmentations / cropping
    augmentations = {
        "cls_augmentations": "TimmImagenetAug",
        "DartsImagenetAug#0.crop_size": 224,
    }
    data_set = get_imagenet(
        data_dir=args.data_dir,
        num_workers=args.data_num_workers,
        batch_size=args.data_batch_size,
        aug_dict=augmentations,
    )

    # network, in eval mode on the gpu
    network = get_network(args.config_path, data_set.get_data_shape(), data_set.get_label_shape(), args.weights_path)
    network.eval()
    network = network.cuda()

    # measure the accuracy, weighted by the number of samples in each batch
    sum_top1, sum_top5, num_samples = 0, 0, 0
    max_batches = args.num_batches
    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(data_set.test_loader()):
            if 0 < max_batches <= batch_idx:
                break
            outputs = network(data.cuda())
            acc1, acc5 = accuracy(outputs, targets.cuda(), topk=(1, 5))
            batch_size = data.size(0)
            sum_top1 += acc1 * batch_size
            sum_top5 += acc5 * batch_size
            num_samples += batch_size

    logger.info('results:')
    logger.info('\ttested images: %d' % num_samples)
    for fmt, value in (('\ttop1: %.4f (%d/%d)', sum_top1), ('\ttop5: %.4f (%d/%d)', sum_top5)):
        logger.info(fmt % (value / num_samples, value, num_samples))
def find_last_checkpoint_path(cls, save_dir: str, index=0, try_general_checkpoint=True) -> str:
    """
    attempt finding a checkpoint path, the candidates are tried in this order:
        1) 'save_dir' itself, if it is a file
        2) the general checkpoint file in 'save_dir', if 'try_general_checkpoint'
        3) the tracked checkpoint of the CheckpointCallback with index 'index' that has the last path in sorted order
        4) any 'checkpoint.pt' file below 'save_dir'

    :return: the checkpoint path, an empty string if none was found
    """
    save_dir = replace_standard_paths(save_dir)

    # try as path and general path
    if os.path.isfile(save_dir):
        return save_dir
    if try_general_checkpoint:
        general_path = cls._general_checkpoint_file(save_dir)
        if os.path.isfile(general_path):
            return general_path

    # try by index
    infos = list(cls.list_infos(save_dir, index))
    infos.sort(key=lambda info: info.checkpoint_path)
    if infos:
        return infos[-1].checkpoint_path

    # try to find any checkpoint.pt in dir
    candidates = glob.glob('%s/**/checkpoint.pt' % save_dir, recursive=True)
    if candidates:
        return candidates[0]

    # failure
    LoggerManager().get_logger().info('Can not find a uninas checkpoint (history) in: %s' % save_dir)
    return ''
def load_network(cls, file_path: str, network: nn.Module, num_replacements=1) -> bool:
    """
    load the network weights from a method checkpoint file,
    parts of the parameter names are replaced to match the requirements

    :param file_path: passed to 'load_last_checkpoint'
    :param network: module that receives the (renamed) state dict, loaded strictly
    :param num_replacements: how often each pattern is replaced per parameter name
    :return: True if weights were loaded, False if the checkpoint is empty
    """
    checkpoint = cls.load_last_checkpoint(file_path)
    if len(checkpoint) == 0:
        return False

    # a checkpoint without a 'state_dict' entry is used as state dict itself
    state_dict = checkpoint.get('state_dict', checkpoint)

    # map state dict keys accordingly
    key_mappings = {'net.': ''}
    net_state_dict = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for pattern, replacement in key_mappings.items():
            new_key = new_key.replace(pattern, replacement, num_replacements)
        net_state_dict[new_key] = value

    network.load_state_dict(net_state_dict, strict=True)
    LoggerManager().get_logger().info('Loaded weights from file: %s' % file_path)
    return True
def _build(self, s_in: Shape, s_out: Shape) -> ShapeList:
    """
    build the network via the parent class, then log a table with the
    input/output shapes and parameter counts of the stem, each cell, the head(s) and the complete network

    :return: the network output shapes (not flattened)
    """
    # copy before building, the copy is used for the 'complete network' row
    network_in = s_in.copy(copy_id=True)
    super()._build(s_in, s_out)

    header = ('cell index', 'input shapes', 'output shapes', '#params')
    stem_row = ('stem', s_in.str(), self.get_stem_output_shape(), count_parameters(self.get_stem()))
    rows = [header, stem_row]
    LoggerManager().get_logger().info('%s (%s):' % (self.__class__.__name__, self.model_name))

    # one row per cell
    cell_inputs = self.get_cell_input_shapes(flatten=False)
    cell_outputs = self.get_cell_output_shapes(flatten=False)
    cells = self.get_cells()
    for idx, (cell_in, cell_out, cell) in enumerate(zip(cell_inputs, cell_outputs, cells)):
        rows.append((idx, cell_in.str(), cell_out.str(), count_parameters(cell)))

    # heads and summary
    head_row = ('head(s)',
                self.get_heads_input_shapes(),
                self.get_network_output_shapes(flatten=False),
                count_parameters(self.get_heads()))
    rows.append(head_row)
    summary_row = ("complete network",
                   network_in.str(),
                   self.get_network_output_shapes(flatten=False),
                   count_parameters(self))
    rows.append(summary_row)

    log_in_columns(LoggerManager().get_logger(), rows, start_space=4)
    return self.get_network_output_shapes(flatten=False)
def __init__(self, args: Namespace, wildcards: dict, descriptions: dict = None):
    """
    set up a task: seed the random number generators, prepare the save dir, logging and weight strategies

    :param args: parsed arguments of the task, stored as 'self.args'
    :param wildcards: passed on to 'save_as_json' when the task config is saved
    :param descriptions: optional argument descriptions, passed on to 'log_args'
    """
    super().__init__()

    # args, seed
    self.args = args
    self.save_dir = self._parsed_argument('save_dir', args)
    self.is_test_run = self._parsed_argument('is_test_run', args)
    self.seed = self._parsed_argument('seed', args)
    self.is_deterministic = self._parsed_argument('is_deterministic', args)
    random.seed(self.seed)
    np.random.seed(self.seed)
    torch.manual_seed(self.seed)
    if self.is_deterministic:
        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # 'torch.set_deterministic' is deprecated in favor of 'torch.use_deterministic_algorithms',
        # prefer the new function and only fall back to the old one on PyTorch versions that lack it
        set_deterministic = getattr(torch, 'use_deterministic_algorithms', None)
        if set_deterministic is None:
            set_deterministic = torch.set_deterministic
        set_deterministic(self.is_deterministic)

    # maybe delete old dir, note arguments, save run_config
    if self._parsed_argument('save_del_old', args):
        shutil.rmtree(self.save_dir, ignore_errors=True)
    os.makedirs(self.save_dir, exist_ok=True)
    save_as_json(args, get_task_config_path(self.save_dir), wildcards)
    # NOTE(review): the file names are appended directly, this assumes that 'save_dir' ends with a separator
    dump_system_info(self.save_dir + 'sysinfo.txt')

    # logging
    self.log_file = '%slog_task.txt' % self.save_dir
    LoggerManager().set_logging(default_save_file=self.log_file)
    self.logger = self.new_logger(index=None)
    log_args(self.logger, None, self.args, add_git_hash=True, descriptions=descriptions)
    Register.log_all(self.logger)

    # reset weight strategies so that consecutive tasks do not conflict with each other
    StrategyManager().reset()

    self.methods = []
def run(self):
    """
    write the current interactive config to a temporary run config file and run it as a new task,
    any exception is logged and shown in a warning dialog instead of being raised
    """
    try:
        config_path = replace_standard_paths('{path_tmp}/tmp.run_config')
        self.interactive.to_json(config_path)
        task = Main.new_task(config_path)
        task.run()
    except Exception as err:
        LoggerManager().get_logger().error(str(err), exc_info=err)
        tkm.showwarning(message=str(err))
def load(cls, file_path: str, pl_module: AbstractMethod = None) -> dict:
    """
    load a method checkpoint from a method checkpoint file and return it

    :param file_path: checkpoint file, standard path placeholders are replaced
    :param pl_module: if given, its state dict is loaded and its 'on_load_checkpoint' hook is called
    :return: the loaded checkpoint, an empty dict if the file does not exist
    """
    file_path = replace_standard_paths(file_path)

    # nothing to load
    if not os.path.isfile(file_path):
        LoggerManager().get_logger().info('Can not load weights, does not exist / not a file: %s' % file_path)
        return {}

    LoggerManager().get_logger().info('Found checkpoint: %s' % file_path)
    checkpoint = torch.load(file_path)
    if pl_module is not None:
        weights = checkpoint['state_dict']
        pl_module.load_state_dict(weights)
        pl_module.on_load_checkpoint(checkpoint)
        LoggerManager().get_logger().info('Loaded weights from file: %s' % file_path)
    return checkpoint
def profile_macs(self) -> np.int64:
    """
    profile the macs for a single forward pass on a single data point

    :return: the macs as profiled by the network, -1 if profiling failed (the error is logged)
    """
    try:
        return self.net.profile_macs()
    except Exception as err:
        # only the first 500 characters of the error message are logged
        LoggerManager().get_logger().error("Failed profiling macs:\n%s\n..." % str(err)[:500])
    return -1
def explore(mini: MiniNASTabularBenchmark, logger: Logger = None, n=-1, sort_by='acc1', maximize=True):
    """
    log a table of the benchmark results, sorted by one metric

    :param mini: the benchmark to explore
    :param logger: logger to use, the default logger if None
    :param n: log at most this many results, all of them if <= 0
    :param sort_by: metric that the results are sorted by
    :param maximize: passed to 'get_all_sorted' together with 'sort_by'
    """
    if logger is None:
        logger = LoggerManager().get_logger()

    # the headline names the actual sort settings, for the defaults it still reads "highest acc1 topologies"
    direction = 'highest' if maximize else 'lowest'
    log_headline(logger, "%s %s topologies (%s, %s, %s)"
                 % (direction, sort_by, mini.get_name(), mini.get_default_data_set(), mini.get_default_result_type()))

    rows = [("%s rank" % sort_by, "acc1", "loss", "params", "flops", "latency", "tuple")]
    for i, r in enumerate(mini.get_all_sorted([sort_by], [maximize])):
        # stop early only if a positive limit is given
        if i >= n > 0:
            break
        rows.append((i, r.get_acc1(), r.get_loss(), r.get_params(), r.get_flops(), r.get_latency(), r.arch_tuple))
    log_in_columns(logger, rows)
def __init__(self, profile_fun: AbstractProfileFunction = None, is_test_run=False, **__):
    """
    :param profile_fun: optional profile function, its class name and 'is_test_run' are stored if it is given
    :param is_test_run: whether this is a test run
    :param __: additional keyword arguments are ignored
    """
    super().__init__()
    # the stored data always names the profiler class
    meta = dict(cls=self.__class__.__name__)
    self.data = dict(meta=meta)
    if profile_fun is not None:
        fun_name = profile_fun.__class__.__name__
        self.set_all(profile_cls=fun_name, is_test_run=is_test_run)
    self.profile_fun = profile_fun
    self.logger = LoggerManager().get_logger()
def get_logger(cls, logger=None, is_test_run=False, save_dir='/tmp/', suffix=''):
    """
    return the given logger, or get a new one if none is given

    a new logger uses the DEBUG level in test runs (INFO otherwise)
    and the save file '<save_dir>log_trainer<suffix>.txt'
    """
    if logger is None:
        level = logging.DEBUG if is_test_run else logging.INFO
        log_file = '%slog_trainer%s.txt' % (save_dir, suffix)
        logger = LoggerManager().get_logger(default_level=level, save_file=log_file)
    return logger
def get_weights_criterion(self) -> (list, AbstractCriterion):
    """
    get the weightings of the network heads and the criterion that is created with them

    :return: (head weights, criterion)
    """
    head_weights = self.net.get_head_weightings()
    criterion_cls = self._parsed_meta_argument(Register.criteria, 'cls_criterion', self.hparams, index=None)
    criterion = criterion_cls(head_weights, self.hparams, self.data_set)

    # the weighting is only logged if there is more than one head
    has_multiple_heads = len(head_weights) > 1
    if has_multiple_heads:
        LoggerManager().get_logger().info("Weighting model heads: %s" % str(head_weights))
    return head_weights, criterion
def __init__(self, hparams: Namespace):
    """
    :param hparams: the method arguments, the 'mask_indices' argument lists the indices to mask
    """
    super().__init__(hparams)
    self.update_architecture_weights = True
    self.train_loader = None

    # mask every given index in the strategy manager
    mask_indices = self._parsed_argument('mask_indices', self.hparams)
    for masked_idx in split(mask_indices, cast_fun=int):
        self.strategy_manager.mask_index(masked_idx)
        LoggerManager().get_logger().info("Globally masking arc choices with index %d" % masked_idx)
def _cb_add_str(self, *_, **__):
    """
    callback: add the currently selected value to the meta argument of this node,
    then reset the selection and update the view; does nothing while config changes are disabled
    """
    if not self._config_enabled:
        return

    selected = None
    try:
        selected = self.var_add_str.get()
        self.interactive.add_meta_value(self.node.name, self.meta.argument.name, selected)
    except Exception as err:
        # failures are logged and shown, they do not stop the update
        LoggerManager().get_logger().error(str(err), exc_info=err)
        tkm.showwarning(message=str(err))

    self.var_add_str.set(self.str_add)
    self.update_content()
    self.update()
def from_file(cls, file_path: str):
    """
    create and load a profiler from a profiler save file

    the profiler class is looked up by the class name that is stored in the meta data of the file,
    a warning is logged if the file was created in a test run
    """
    file_path = replace_standard_paths(file_path)
    assert os.path.isfile(file_path), "File does not exist: %s" % str(file_path)

    # find the profiler class that wrote the file
    meta = torch.load(file_path).get('meta')
    profiler_cls = Register.profilers.get(meta.get('cls'))

    profiler = profiler_cls()
    profiler.load(file_path)
    if profiler.get('is_test_run'):
        LoggerManager().get_logger().warning("Loading profiler data from a file created in a test run!")
    return profiler
def run(self) -> 'AbstractTask':
    """
    execute the task, then flush the logging of all methods

    the logging is cleaned up in any case, exceptions of the task are passed on unchanged

    :return: this task
    """
    # no 'except' clause: catching an exception only to re-raise it added nothing but a traceback frame
    try:
        self._run()
        for method in self.methods:
            method.flush_logging()
        self.logger.info("Done!")
        return self
    finally:
        LoggerManager().cleanup()
def add_cells_from_config(self, config: dict):
    """
    add all cell types in the given config

    a cell type of the same name that is already known (as partial or as config) is replaced,
    every cell type is logged as either 'Replaced' or 'Added'

    :param config: network config, the cell configs are read from config['kwargs']['cell_configs']
    """
    for name, cfg in config.get('kwargs').get('cell_configs').items():
        already_had = name in self.cell_configs
        if name in self._cell_partials:
            already_had = True
            # a config replaces the partial of the same name
            self._cell_partials.pop(name)
        # log in both cases, the message was previously only reachable for replaced cell types
        LoggerManager().get_logger().info('%s cell type "%s" from given config'
                                          % ('Replaced' if already_had else 'Added', name))
        self.cell_configs[name] = cfg
def save_config(cls, config: dict, config_dir: str, config_name: str) -> str:
    """
    write a network config as json file

    :param config: the config to save, nothing is written if it is None
    :param config_dir: directory for the config file, standard path placeholders are replaced
    :param config_name: file name without extension, 'cls.extension_net_config' is appended
    :return: absolute path of the written file, an empty string if 'config' is None
    """
    if config is None:
        return ''
    # resolve the placeholders first, so that the created directory is the one the file is written to
    config_dir = replace_standard_paths(config_dir)
    os.makedirs(config_dir, exist_ok=True)
    path = os.path.abspath('%s/%s%s' % (config_dir, config_name, cls.extension_net_config))
    with open(path, 'w') as outfile:
        json.dump(config, outfile, ensure_ascii=False, indent=2)
    LoggerManager().get_logger().info('Wrote net config to %s' % path)
    return path
def _cb_rem_str(self, *_, **__):
    """
    callback: remove the selected entry from the meta argument of this node, then update the view;
    the entry index is the part after the first '#' in the selection, 0 if there is no '#';
    does nothing while config changes are disabled
    """
    if not self._config_enabled:
        return

    parts = self.var_rem_str.get().split('#')
    entry_idx = int(parts[1]) if len(parts) >= 2 else 0
    try:
        self.interactive.remove_meta_index(self.node.name, self.meta.argument.name, entry_idx)
    except Exception as err:
        # failures are logged and shown, they do not stop the update
        LoggerManager().get_logger().error(str(err), exc_info=err)
        tkm.showwarning(message=str(err))

    self.update_content()
    self.update()
def save_as_config(self, path=None):
    """
    save the current interactive config as run config file

    :param path: where to save, a file dialog asks for the path if None
    """
    if path is None:
        path = filedialog.asksaveasfilename(initialdir=self._init_dir,
                                            initialfile=name_task_config,
                                            title="Select file",
                                            filetypes=(("run config files", "*.run_config"), ))

    # ignore anything that is not a non-empty string, and directories
    is_usable = isinstance(path, str) and not os.path.isdir(path) and len(path) > 0
    if not is_usable:
        return

    try:
        self.interactive.to_json(path)
    except Exception as err:
        LoggerManager().get_logger().error(str(err), exc_info=err)
        tkm.showwarning(message=str(err))

    # remember the location for the next dialog
    self._init_dir = os.path.dirname(path)
    self._save_path = path
def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList: """ build the network """ # find the search config if not os.path.isfile(self.search_config_path): self.search_config_path = Builder.find_net_config_path( self.search_config_path, pattern='search') # create a temporary search strategy tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__') sm = StrategyManager() assert len(sm.get_strategies_list( )) == 0, "can not load when there already is a search network" sm.add_strategy(tmp_s) sm.set_fixed_strategy_name('__tmp__') # create a search network search_net = Register.builder.load_from_config(self.search_config_path) assert isinstance(search_net, SearchUninasNetwork) search_net.build(s_in, s_out) search_net.set_forward_strategy(False) # set the architecture, get the config req_gene = "" if self.gene == 'random': search_net.forward_strategy() gene = sm.get_all_finalized_indices(unique=True, flat=True) self.model_name = "random(%s)" % str(gene) req_gene = " (%s)" % self.gene else: gene = split(self.gene, int) l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene) assert l0 == l1, "number of unique choices in the network (%d) must match length of the gene (%d)" % ( l0, l1) search_net.forward_strategy(fixed_arc=gene) config = search_net.config(finalize=True) # clean up sm.delete_strategy('__tmp__') del sm del search_net # build the actually used finalized network LoggerManager().get_logger().info( "Extracting architecture %s%s from the super-network" % (gene, req_gene)) self.net = Register.builder.from_config(config) return self.net.build(s_in, s_out)
def save(cls, file_path: str, pl_module: AbstractMethod, update_dict: dict = None) -> dict:
    """
    save method checkpoint to file, not tracking it

    :param file_path: where to save the checkpoint, missing directories are created
    :param pl_module: the method, its state dict is saved and its 'on_save_checkpoint' hook is called
    :param update_dict: optional additional entries for the checkpoint
    :return: the saved checkpoint
    """
    # a bare file name has no directory part, and os.makedirs('') raises
    save_dir = os.path.dirname(file_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    checkpoint = dict(state_dict=pl_module.state_dict())
    if isinstance(update_dict, dict):
        checkpoint.update(update_dict)
    pl_module.on_save_checkpoint(checkpoint)
    cls.atomic_save(file_path, checkpoint)
    LoggerManager().get_logger().info('Saved weights to file: %s' % file_path)
    return checkpoint
def load_add_config(self, add=True):
    """
    ask for a run config file and load it into the interactive config

    :param add: add to the current config if True, otherwise reset the current config first
    """
    path = filedialog.askopenfilename(initialdir=self._init_dir,
                                      initialfile=name_task_config,
                                      title="Select file",
                                      filetypes=(("run config files", "*.run_config"), ))

    # ignore anything that is not the path of an existing file
    is_config_file = isinstance(path, str) and os.path.isfile(path) and len(path) > 0
    if not is_config_file:
        return

    self.enable_config_changes(False)
    if not add:
        self.reset()
    try:
        self.interactive.from_json(path)
    except Exception as err:
        LoggerManager().get_logger().error(str(err), exc_info=err)
        tkm.showwarning(message=str(err))
    self.update_content()
    self.update()
def maybe_download(path_or_url: str, file_type: FileType = FileType.MISC) -> Union[str, None]:
    """
    if the file does not locally exist at the given path, try to use a cached download,
    otherwise download it, then return the path

    :param path_or_url: local path or url, standard path placeholders are replaced
    :param file_type: its value is the directory in which downloads are cached
    :return: path to the local file, None if it could neither be found nor downloaded
    """
    path_or_url = replace_standard_paths(path_or_url)
    if os.path.isfile(path_or_url):
        return path_or_url

    try:
        os.makedirs(file_type.value, exist_ok=True)
        # the cached file is named like the last part of the url
        file_path = '%s/%s' % (file_type.value, path_or_url.split('/')[-1])
        if not os.path.isfile(file_path):
            urlretrieve(path_or_url, file_path)
            LoggerManager().get_logger().info("downloaded %s to %s" % (path_or_url, file_path))
        return file_path
    except Exception:
        # best effort: any failure means that the file is not available,
        # but KeyboardInterrupt / SystemExit are no longer swallowed
        return None
def explore(mini: MiniNATSBenchTabularBenchmark):
    """
    log some specific results of the benchmark, followed by two tables:
    the ten topologies with the highest acc1, and the ten of them that have no 1 in their arch tuple
    """
    logger = LoggerManager().get_logger()

    # some stats of specific results
    logger.info(mini.get_by_arch_tuple((4, 3, 2, 1, 0, 2)).get_info_str('cifar10'))
    logger.info("")
    specific_results = (
        lambda: mini.get_by_arch_tuple((1, 2, 1, 2, 3, 4)),
        lambda: mini.get_by_index(1554),
        lambda: mini.get_by_arch_str(
            '|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|'),
    )
    for get_result in specific_results:
        get_result().print(logger.info)
        logger.info("")

    # best results by acc1
    best_rows = [("acc1", "params", "arch tuple", "arch str")]
    log_headline(logger, "highest acc1 topologies (%s, %s, %s)"
                 % (mini.get_name(), mini.get_default_data_set(), mini.get_default_result_type()))
    for rank, result in enumerate(mini.get_all_sorted(['acc1'], [True])):
        best_rows.append((result.get_acc1(), result.get_params(), str(result.arch_tuple), result.arch_str))
        if rank >= 9:
            break
    log_in_columns(logger, best_rows)
    logger.info("")

    # best results by acc1, only architectures without a 1 in their arch tuple
    no_skip_rows = [("acc1", "arch tuple", "arch str")]
    num_found = 0
    log_headline(logger, "highest acc1 topologies without skip (%s, %s, %s)"
                 % (mini.get_name(), mini.get_default_data_set(), mini.get_default_result_type()))
    for result in mini.get_all_sorted(['acc1'], [True]):
        if 1 in result.arch_tuple:
            continue
        no_skip_rows.append((result.get_acc1(), str(result.arch_tuple), result.arch_str))
        num_found += 1
        if num_found > 9:
            break
    log_in_columns(logger, no_skip_rows)
def initialize_weights(self, net: AbstractModule):
    """
    log which initializer is used, then initialize the weights of the given module

    :param net: the module, passed on to '_initialize_weights' together with the logger
    """
    log = LoggerManager().get_logger()
    initializer_name = self.__class__.__name__
    log.info('Initializing: %s' % initializer_name)
    self._initialize_weights(net, log)
def __init__(self, data_dir: str, save_dir: Union[str, None],
             bs_train: int, bs_test: int,
             train_transforms: transforms.Compose, test_transforms: transforms.Compose,
             train_batch_aug: Union[BatchAugmentations, None], test_batch_aug: Union[BatchAugmentations, None],
             num_workers: int, num_prefetch: int,
             valid_split: Union[int, float], valid_shuffle: bool,
             fake: bool, download: bool, **additional_args):
    """
    load (or fake) the train/test data and set up the validation data,
    which is either split off the training set, a dedicated validation set, or not used at all

    :param data_dir: where to find (or download) the data set
    :param save_dir: global save dir, can store and reuse the info which data was used in the random valid split
    :param bs_train: batch size for the train loader
    :param bs_test: batch size for the test loader, <= 0 to have the same as bs_train
    :param train_transforms: train augmentations (on each data point individually)
    :param test_transforms: test augmentations (on each data point individually)
    :param train_batch_aug: train augmentations (across the entire batch)
    :param test_batch_aug: test augmentations (across the entire batch)
    :param num_workers: number of workers prefetching data
    :param num_prefetch: number of batches prefetched by every worker
    :param valid_split: absolute number of data points if int or >1, otherwise a fraction of the training set
    :param valid_shuffle: whether to shuffle validation data
    :param fake: use fake data instead (no need to provide either real data or enabling downloading)
    :param download: whether downloading is allowed
    :param additional_args: arguments that are added and used by child classes
    """
    super().__init__()
    logger = LoggerManager().get_logger()
    self.dir = data_dir
    self.bs_train = bs_train
    self.bs_test = bs_test if bs_test > 0 else self.bs_train
    self.num_workers, self.num_prefetch = num_workers, num_prefetch
    self.valid_shuffle = valid_shuffle
    self.additional_args = additional_args
    self.fake = fake
    # fake data never needs to be downloaded
    self.download = download and not self.fake
    if self.download and (not self.can_download):
        logger.warning("The dataset can not be downloaded, but may be asked to.")
    self.train_transforms = train_transforms
    self.test_transforms = test_transforms
    self.train_batch_augmentations = train_batch_aug
    self.test_batch_augmentations = test_batch_aug

    # load/create meta info dict
    if isinstance(save_dir, str) and len(save_dir) > 0:
        meta_path = '%s/data.meta.pt' % replace_standard_paths(save_dir)
        if os.path.isfile(meta_path):
            meta = torch.load(meta_path)
        else:
            meta = defaultdict(dict)
    else:
        meta, meta_path = defaultdict(dict), None

    # give subclasses a good spot to react to additional arguments
    self._before_loading()

    # data
    if self.fake:
        train_data = self._get_fake_train_data(self.train_transforms)
        self.test_data = self._get_fake_test_data(self.test_transforms)
    else:
        train_data = self._get_train_data(self.train_transforms)
        self.test_data = self._get_test_data(self.test_transforms)

    # split train into train+valid or using stand-alone valid set
    if valid_split > 0:
        s1 = int(valid_split) if valid_split >= 1 else int(len(train_data)*valid_split)
        if s1 >= len(train_data):
            logger.warning("Tried to set valid split larger than the training set size, setting to 0.5")
            s1 = len(train_data)//2
        s0 = len(train_data) - s1
        # the random permutation is stored per split size, so that the same split can be reused
        if meta['splits'].get((s0, s1), None) is None:
            meta['splits'][(s0, s1)] = torch.randperm(s0+s1).tolist()
        indices = meta['splits'][(s0, s1)]
        # 'np.int' was only an alias of the builtin int and has been removed from NumPy
        self.valid_data = torch.utils.data.Subset(train_data, np.array(indices[s0:]).astype(int))
        train_data = torch.utils.data.Subset(train_data, np.array(indices[0:s0]).astype(int))
        logger.info('Data Set: splitting training set, will use %s data points as validation set' % s1)
        if self.length[1] > 0:
            logger.info('Data Set: a dedicated validation set exists, but it will be replaced.')
    elif self.length[1] > 0:
        if self.fake:
            self.valid_data = self._get_fake_valid_data(self.test_transforms)
        else:
            self.valid_data = self._get_valid_data(self.test_transforms)
        logger.info('Data Set: using the dedicated validation set with test augmentations')
    else:
        self.valid_data = None
        logger.info('Data Set: not using a validation set at all.')
    self.train_data = train_data

    # shapes, taken from the first training data point
    data, _ = self.train_data[0]
    self.data_shape = Shape(list(data.shape))

    # save meta info dict
    if meta_path is not None:
        torch.save(meta, meta_path)
def log(cls, msg: str):
    """ log the message at info level, prefixed with the name of this class """
    logger = LoggerManager().get_logger()
    text = '%s: %s' % (cls.__name__, msg)
    logger.info(text)