def __init__(self, args: Namespace, wildcards: dict, descriptions: dict = None):
    """
    set up the task: read the basic arguments, seed all RNGs, prepare the save dir and start logging

    :param args: parsed arguments, kept as self.args
    :param wildcards: handed to save_as_json when the task config is written
    :param descriptions: handed to log_args, optional
    """
    super().__init__()

    # args, seed
    self.args = args
    self.save_dir = self._parsed_argument('save_dir', args)
    self.is_test_run = self._parsed_argument('is_test_run', args)
    self.seed = self._parsed_argument('seed', args)
    self.is_deterministic = self._parsed_argument('is_deterministic', args)
    random.seed(self.seed)
    np.random.seed(self.seed)
    torch.manual_seed(self.seed)
    if self.is_deterministic:
        # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # torch.set_deterministic is deprecated in favor of torch.use_deterministic_algorithms,
        # prefer the new name and only fall back to the old one if this torch does not have it
        set_deterministic = getattr(torch, 'use_deterministic_algorithms', None)
        if set_deterministic is None:
            set_deterministic = torch.set_deterministic
        set_deterministic(self.is_deterministic)

    # maybe delete old dir, note arguments, save run_config
    # NOTE(review): the paths below are built by plain string concatenation,
    # so save_dir presumably ends with a path separator - confirm
    if self._parsed_argument('save_del_old', args):
        shutil.rmtree(self.save_dir, ignore_errors=True)
    os.makedirs(self.save_dir, exist_ok=True)
    save_as_json(args, get_task_config_path(self.save_dir), wildcards)
    dump_system_info(self.save_dir + 'sysinfo.txt')

    # logging
    self.log_file = '%slog_task.txt' % self.save_dir
    LoggerManager().set_logging(default_save_file=self.log_file)
    self.logger = self.new_logger(index=None)
    log_args(self.logger, None, self.args, add_git_hash=True, descriptions=descriptions)
    Register.log_all(self.logger)

    # reset weight strategies so that consecutive tasks do not conflict with each other
    StrategyManager().reset()

    self.methods = []
def __init__(self, args: Namespace, index=None, trainer: SimpleTrainer = None, load_path: str = None,
             method: AbstractMethod = None, **kwargs):
    """
    store trainer, load path and method, after checking that everything the class was registered to require is given

    :param args: parsed arguments, passed on to the super class
    :param index: passed on to the super class
    :param trainer: must be a SimpleTrainer if the class is registered with 'requires_trainer'
    :param load_path: must be a str if the class is registered with 'requires_trainer'
    :param method: must be an AbstractMethod with a search network if the class is registered
                   with 'requires_method', taken from the trainer if not given
    :param kwargs: passed on to the super class
    """
    super().__init__(args, index=index, **kwargs)
    own_name = self.__class__.__name__

    # ensure all required parameters are available
    if method is None and trainer is not None:
        method = trainer.get_method()
    requirements = Register.get_my_kwargs(self.__class__)

    if requirements.get('requires_trainer'):
        assert isinstance(trainer, SimpleTrainer), "%s needs a trainer" % own_name
        assert isinstance(load_path, str), "%s needs a path to load weights" % own_name

    if requirements.get('requires_method'):
        assert isinstance(method, AbstractMethod), "%s needs a method" % own_name
        assert isinstance(method.get_network(), SearchUninasNetwork),\
            "%s's method must use a search network" % own_name

    self.method, self.trainer, self.load_path = method, trainer, load_path
def meta_args_to_add(cls) -> [MetaArgument]:
    """
    meta arguments of the super class, followed by one for the augmentation sets
    that match the 'images' kwarg this class was registered with
    """
    on_images = Register.get_my_kwargs(cls).get('images')
    augmentations = Register.augmentation_sets.filter_match_all(on_images=on_images)
    inherited = super().meta_args_to_add()
    own = [MetaArgument('cls_augmentations', augmentations, help_name='data augmentation')]
    return inherited + own
def meta_args_to_add(cls) -> [MetaArgument]:
    """
    meta arguments of the super class, followed by exactly one device manager, one trainer and one method,
    the methods are filtered by the 'search' kwarg this class was registered with
    """
    search = Register.get_my_kwargs(cls).get('search')
    matching_methods = Register.methods.filter_match_all(search=search)
    inherited = super().meta_args_to_add()
    own = [
        MetaArgument('cls_device', Register.devices_managers, help_name='device manager', allowed_num=1),
        MetaArgument('cls_trainer', Register.trainers, help_name='trainer', allowed_num=1),
        MetaArgument('cls_method', matching_methods, help_name='method', allowed_num=1),
    ]
    return inherited + own
def meta_args_to_add(cls, num_optimizers=1, search=True) -> [MetaArgument]:
    """
    meta arguments of the super class, followed by those for data set, network, criterion, metrics,
    initializers, regularizers, optimizers and schedulers

    :param num_optimizers: number of optimizers to choose, also the upper limit for the schedulers
    :param search: used to filter the registered networks
    """
    distill = Register.get_my_kwargs(cls).get('distill')
    # metrics and criteria are filtered by the 'distill' kwarg of the class, networks by the search argument
    matching_metrics = Register.metrics.filter_match_all(distill=distill)
    matching_criteria = Register.criteria.filter_match_all(distill=distill)
    matching_networks = Register.networks.filter_match_all(search=search)

    inherited = super().meta_args_to_add()
    own = [
        MetaArgument('cls_data', Register.data_sets, help_name='data set', allowed_num=1),
        MetaArgument('cls_network', matching_networks, help_name='network', allowed_num=1),
        MetaArgument('cls_criterion', matching_criteria, help_name='criterion', allowed_num=1),
        MetaArgument('cls_metrics', matching_metrics, help_name='training metric'),
        MetaArgument('cls_initializers', Register.initializers, help_name='weight initializer'),
        MetaArgument('cls_regularizers', Register.regularizers, help_name='regularizer'),
        MetaArgument('cls_optimizers', Register.optimizers, help_name='optimizer',
                     allow_duplicates=True, allowed_num=num_optimizers, use_index=True),
        MetaArgument('cls_schedulers', Register.schedulers, help_name='scheduler',
                     allow_duplicates=True, allowed_num=(0, num_optimizers), use_index=True),
    ]
    return inherited + own
def maybe_add_cls_tooltip(name: str, label: tk.Label = None, tooltip: CreateToolTip = None):
    """
    show the docstring and the class path of a registered class as tooltip, does nothing if that fails

    :param name: name of the class in the Register, everything from the first '#' on is ignored
    :param label: receives a newly created tooltip, only used if no tooltip is given
    :param tooltip: existing tooltip, its text is replaced
    """
    cls_name = name.split('#')[0]
    try:
        cls = Register.get(cls_name)
        if cls is None:
            return

        parts = []
        if cls.__doc__:
            for i, line in enumerate(cls.__doc__.split('\n')):
                # only an empty first line is skipped, later empty lines are kept
                if i == 0 and len(line) == 0:
                    continue
                # NOTE(review): the stripped literal is reproduced exactly as found,
                # confirm that it matches the indentation width of the docstrings
                parts.append(line.replace(' ', '', 1) + '\n')
            parts.append('\n\n')
        parts.append('(implemented in: %s)' % get_class_path(cls))
        text = ''.join(parts)

        if tooltip is None:
            CreateToolTip(label, text=text, wraplength=sizes.get('wrap_tooltip'))
        else:
            tooltip.text = text
    except Exception:
        # tooltips are best effort, a failure here must never break the gui,
        # but KeyboardInterrupt / SystemExit are no longer swallowed
        pass
def calculate(cls, data0: list, data1: list) -> float: """ calculate and return the correlation value """ r, p = spearmanr(data0, data1) return r if __name__ == '__main__': import random scc = SpearmanCorrelation(column_names=('predicted accuracy', 'true accuracy'), add_lines=True, can_show=True) x = [v / 10 for v in range(-10, 10, 1)] y1 = [xi**3 for xi in x] y2 = [xi**4 for xi in x] y3 = [0 for xi in x] x2 = [(v + random.random()) * 0.01 for v in range(-10, 10, 1)] y21 = [xi + (random.random() - 0.5) for xi in x] scc.add_data(x, y1, 'data #1', other_metrics=(PearsonCorrelation, )) scc.add_data(x, y2, 'data #2', other_metrics=(PearsonCorrelation, )) scc.add_data(x, y3, 'data #3', other_metrics=(PearsonCorrelation, )) scc.add_data(x2, y21, 'data #4', other_metrics=(PearsonCorrelation, )) scc.plot(title=scc.__class__.__name__, legend=True, show=True, save_path=None) except ImportError as e: Register.missing_import(e)
def num_classes(cls) -> int:
    """ number of features of the label shape, asserts that the class is registered with 'classification' """
    registered_for_classification = Register.get_my_kwargs(cls).get('classification', False)
    assert registered_for_classification
    return cls.label_shape.num_features()
def is_classification(cls) -> bool:
    """ the 'classification' kwarg this class was registered with, False if it was not set """
    return Register.get_my_kwargs(cls).get('classification', False)
def is_on_images(cls) -> bool:
    """ the 'images' kwarg this class was registered with, False if it was not set """
    return Register.get_my_kwargs(cls).get('images', False)
def is_single_path(cls) -> bool:
    """ the 'single_path' kwarg this class was registered with, False if it was not set """
    registered_kwargs = Register.get_my_kwargs(cls)
    return registered_kwargs.get('single_path', False)
def is_external(self) -> bool:
    """ the 'external' kwarg the class of this object was registered with, False if it was not set """
    # explicit default, so that a missing kwarg gives False instead of None
    return Register.get_my_kwargs(self.__class__).get('external', False)
def is_tabular(self) -> bool:
    """ the 'tabular' kwarg the class of this object was registered with, False if it was not set """
    registered_kwargs = Register.get_my_kwargs(self.__class__)
    return registered_kwargs.get('tabular', False)