Esempio n. 1
0
    def __init__(self, args: Namespace, wildcards: dict, descriptions: dict = None):
        """
        Set up a task: read the core arguments, seed the random number generators,
        prepare the save directory, set up logging and reset global weight strategies.

        :param args: parsed command line arguments
        :param wildcards: passed on to save_as_json when storing the task config
        :param descriptions: optional argument descriptions, passed on to log_args
        """
        super().__init__()

        # args, seed
        self.args = args
        self.save_dir = self._parsed_argument('save_dir', args)
        self.is_test_run = self._parsed_argument('is_test_run', args)
        self.seed = self._parsed_argument('seed', args)
        self.is_deterministic = self._parsed_argument('is_deterministic', args)
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if self.is_deterministic:
            # see https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
            os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
            # torch.set_deterministic was deprecated in favour of
            # torch.use_deterministic_algorithms and removed in later torch releases;
            # use the new name when available, fall back to the old one on old versions
            set_deterministic = getattr(torch, 'use_deterministic_algorithms', None)
            if set_deterministic is None:
                set_deterministic = torch.set_deterministic
            set_deterministic(self.is_deterministic)

        # maybe delete old dir, note arguments, save run_config
        if self._parsed_argument('save_del_old', args):
            shutil.rmtree(self.save_dir, ignore_errors=True)
        os.makedirs(self.save_dir, exist_ok=True)
        save_as_json(args, get_task_config_path(self.save_dir), wildcards)
        # NOTE(review): plain concatenation assumes save_dir ends with a path separator
        dump_system_info(self.save_dir + 'sysinfo.txt')

        # logging
        self.log_file = '%slog_task.txt' % self.save_dir
        LoggerManager().set_logging(default_save_file=self.log_file)
        self.logger = self.new_logger(index=None)
        log_args(self.logger, None, self.args, add_git_hash=True, descriptions=descriptions)
        Register.log_all(self.logger)

        # reset weight strategies so that consecutive tasks do not conflict with each other
        StrategyManager().reset()

        self.methods = []
Esempio n. 2
0
    def __init__(self,
                 args: Namespace,
                 index=None,
                 trainer: SimpleTrainer = None,
                 load_path: str = None,
                 method: AbstractMethod = None,
                 **kwargs):
        """
        Store the trainer, weight load path and method. Before storing them, check that
        everything this class was registered as requiring is present.

        :param args: parsed command line arguments, forwarded to the parent
        :param index: optional index, forwarded to the parent
        :param trainer: trainer; when no method is given, the method is taken from it
        :param load_path: path to load weights from
        :param method: method to use; defaults to the trainer's method
        """
        super().__init__(args, index=index, **kwargs)

        # fall back to the trainer's method when none was given explicitly
        if method is None and trainer is not None:
            method = trainer.get_method()

        cls_name = self.__class__.__name__
        requirements = Register.get_my_kwargs(self.__class__)

        if requirements.get('requires_trainer'):
            assert isinstance(trainer, SimpleTrainer), "%s needs a trainer" % cls_name
            assert isinstance(load_path, str), "%s needs a path to load weights" % cls_name

        if requirements.get('requires_method'):
            assert isinstance(method, AbstractMethod), "%s needs a method" % cls_name
            assert isinstance(method.get_network(), SearchUninasNetwork), \
                "%s's method must use a search network" % cls_name

        self.method = method
        self.trainer = trainer
        self.load_path = load_path
Esempio n. 3
0
 def meta_args_to_add(cls) -> [MetaArgument]:
     """
     Return the meta arguments to add to argparse when this class is chosen.
     Classes selected through these meta arguments may bring their own arguments.

     Adds a choice among the augmentation sets that match this class's 'images' flag.
     """
     on_images = Register.get_my_kwargs(cls).get('images')
     matching_sets = Register.augmentation_sets.filter_match_all(on_images=on_images)
     inherited = super().meta_args_to_add()
     own = [MetaArgument('cls_augmentations', matching_sets, help_name='data augmentation')]
     return inherited + own
Esempio n. 4
0
    def meta_args_to_add(cls) -> [MetaArgument]:
        """
        Return the meta arguments to add to argparse when this class is chosen.
        Classes selected through these meta arguments may bring their own arguments.

        Adds exactly-one choices for the device manager and the trainer. It also adds
        a method choice restricted to methods matching this class's 'search' flag.
        """
        is_search = Register.get_my_kwargs(cls).get('search')
        matching_methods = Register.methods.filter_match_all(search=is_search)

        inherited = super().meta_args_to_add()
        own = [
            MetaArgument('cls_device', Register.devices_managers, help_name='device manager', allowed_num=1),
            MetaArgument('cls_trainer', Register.trainers, help_name='trainer', allowed_num=1),
            MetaArgument('cls_method', matching_methods, help_name='method', allowed_num=1),
        ]
        return inherited + own
Esempio n. 5
0
    def meta_args_to_add(cls, num_optimizers=1, search=True) -> [MetaArgument]:
        """
        Return the meta arguments to add to argparse when this class is chosen.
        Classes selected through these meta arguments may bring their own arguments.

        :param num_optimizers: number of optimizers to choose; up to as many schedulers may be chosen
        :param search: passed as the 'search' filter when selecting candidate networks
        """
        distill = Register.get_my_kwargs(cls).get('distill')
        metrics = Register.metrics.filter_match_all(distill=distill)
        criteria = Register.criteria.filter_match_all(distill=distill)
        networks = Register.networks.filter_match_all(search=search)

        inherited = super().meta_args_to_add()
        own = [
            # exactly one of each
            MetaArgument('cls_data', Register.data_sets, help_name='data set', allowed_num=1),
            MetaArgument('cls_network', networks, help_name='network', allowed_num=1),
            MetaArgument('cls_criterion', criteria, help_name='criterion', allowed_num=1),
            # any number of each
            MetaArgument('cls_metrics', metrics, help_name='training metric'),
            MetaArgument('cls_initializers', Register.initializers, help_name='weight initializer'),
            MetaArgument('cls_regularizers', Register.regularizers, help_name='regularizer'),
            # one optimizer per slot, at most as many schedulers
            MetaArgument('cls_optimizers', Register.optimizers, help_name='optimizer',
                         allow_duplicates=True, allowed_num=num_optimizers, use_index=True),
            MetaArgument('cls_schedulers', Register.schedulers, help_name='scheduler',
                         allow_duplicates=True, allowed_num=(0, num_optimizers), use_index=True),
        ]
        return inherited + own
Esempio n. 6
0
def maybe_add_cls_tooltip(name: str,
                          label: tk.Label = None,
                          tooltip: CreateToolTip = None):
    """
    Attach or update a tooltip showing the docstring and implementation path
    of the registered class that `name` refers to.

    Anything from the first '#' in `name` onwards is ignored for the class lookup.
    If `tooltip` is None, a new CreateToolTip is attached to `label`. Otherwise the
    text of the given `tooltip` is replaced. This is best effort: if the class is
    not registered, or building the tooltip raises an Exception, nothing happens.

    :param name: registered class name, optionally followed by '#...'
    :param label: widget to attach a new tooltip to (used when tooltip is None)
    :param tooltip: existing tooltip whose text is updated
    """
    cls_name = name.split('#')[0]
    try:
        cls = Register.get(cls_name)
        if cls is None:
            return
        parts = []
        if cls.__doc__:
            for i, line in enumerate(cls.__doc__.split('\n')):
                # docstrings usually start with a line break; skip that empty first line
                if i == 0 and not line:
                    continue
                # strip one level (4 spaces) of *leading* indentation only;
                # str.replace would also delete 4 spaces found mid-line
                if line.startswith('    '):
                    line = line[4:]
                parts.append(line + '\n')
            parts.append('\n\n')
        parts.append('(implemented in: %s)' % get_class_path(cls))
        text = ''.join(parts)
        if tooltip is None:
            CreateToolTip(label,
                          text=text,
                          wraplength=sizes.get('wrap_tooltip'))
        else:
            tooltip.text = text
    except Exception:
        # deliberately best-effort: a missing tooltip must not break the GUI,
        # but unlike a bare except this lets KeyboardInterrupt/SystemExit through
        pass
Esempio n. 7
0
        def calculate(cls, data0: list, data1: list) -> float:
            """
            Return Spearman's rank correlation coefficient of data0 vs. data1.
            The p-value that spearmanr also computes is discarded.
            """
            return spearmanr(data0, data1)[0]

    if __name__ == '__main__':
        # demo: plots several data series with SpearmanCorrelation, additionally
        # computing PearsonCorrelation for each series
        import random
        scc = SpearmanCorrelation(column_names=('predicted accuracy',
                                                'true accuracy'),
                                  add_lines=True,
                                  can_show=True)
        # x runs from -1.0 to 0.9 in steps of 0.1
        x = [v / 10 for v in range(-10, 10, 1)]
        y1 = [xi**3 for xi in x]  # cubic: increasing over the whole range
        y2 = [xi**4 for xi in x]  # even power: decreasing for x < 0, increasing for x > 0
        y3 = [0 for xi in x]  # constant series
        # x2: small-scale values with random jitter in [0, 0.01) per point
        x2 = [(v + random.random()) * 0.01 for v in range(-10, 10, 1)]
        # y21: x plus uniform noise in [-0.5, 0.5)
        # NOTE(review): y21 is built from x, not x2, yet is paired with x2 below;
        # both increase with the same index, so this may be intended — confirm
        y21 = [xi + (random.random() - 0.5) for xi in x]
        scc.add_data(x, y1, 'data #1', other_metrics=(PearsonCorrelation, ))
        scc.add_data(x, y2, 'data #2', other_metrics=(PearsonCorrelation, ))
        scc.add_data(x, y3, 'data #3', other_metrics=(PearsonCorrelation, ))
        scc.add_data(x2, y21, 'data #4', other_metrics=(PearsonCorrelation, ))
        scc.plot(title=scc.__class__.__name__,
                 legend=True,
                 show=True,
                 save_path=None)

except ImportError as e:
    # closes a try opened above this excerpt (presumably guarding optional imports
    # such as scipy); the ImportError is handed to Register.missing_import
    Register.missing_import(e)
Esempio n. 8
0
 def num_classes(cls) -> int:
     """
     Return the number of features of this class's label shape.
     Asserts that the class was registered with classification=True.
     """
     registered_as_classification = Register.get_my_kwargs(cls).get('classification', False)
     assert registered_as_classification
     label_shape = cls.label_shape
     return label_shape.num_features()
Esempio n. 9
0
 def is_classification(cls) -> bool:
     """Return the 'classification' flag this class was registered with (False if not given)."""
     return Register.get_my_kwargs(cls).get('classification', False)
Esempio n. 10
0
 def is_on_images(cls) -> bool:
     """Return the 'images' flag this class was registered with (False if not given)."""
     return Register.get_my_kwargs(cls).get('images', False)
Esempio n. 11
0
 def is_single_path(cls) -> bool:
     """Return the 'single_path' flag this class was registered with (False if not given)."""
     registered = Register.get_my_kwargs(cls)
     return registered.get('single_path', False)
Esempio n. 12
0
 def is_external(self) -> bool:
     """
     Return the 'external' flag this object's class was registered with.

     Defaults to False (previously None) when the flag was not given. This matches
     the declared bool return type and the other registration-flag getters.
     """
     return Register.get_my_kwargs(self.__class__).get('external', False)
Esempio n. 13
0
 def is_tabular(self) -> bool:
     """Return the 'tabular' flag this object's class was registered with (False if not given)."""
     flags = Register.get_my_kwargs(self.__class__)
     return flags.get('tabular', False)