Example 1
    def find_net_config_path(cls,
                             config_name_or_dir: str,
                             pattern: str = '') -> str:
        """
        find a standard network config, get its full path, try:
            1) look up if config_name_or_dir matches a name (e.g. DARTS) in the common configs
            2) otherwise assume it is a dir, try to return a network config which name contains the pattern

        :param config_name_or_dir: dir or full path of a config
        :param pattern: if a config has to be searched in the dir 
        """
        # already done?
        if '/' in config_name_or_dir and config_name_or_dir.endswith(
                cls.extension_net_config):
            return config_name_or_dir
        # search in the common configs for a given name
        p = replace_standard_paths(
            '{path_conf_net}'
        ) + '/**/' + config_name_or_dir + '*' + cls.extension_net_config
        paths = glob(p, recursive=True)
        if len(paths) == 1:
            return paths[0]
        # find any unique network config in the given path
        p = '%s/**/*%s*%s' % (replace_standard_paths(config_name_or_dir),
                              pattern, cls.extension_net_config)
        paths = glob(p, recursive=True)
        if len(paths) == 1:
            return paths[0]
        raise FileNotFoundError(
            'can not find (unique) network config with given name/path "%s"' %
            config_name_or_dir)
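
All of these snippets rely on replace_standard_paths, which expands wildcards such as {path_tmp} or {path_conf_net} into concrete directories before any file-system access. A minimal self-contained sketch of that idea; the wildcard names appear in the snippets, but the values and the function body below are assumptions for illustration, not the uninas implementation:

# sketch only: illustrative wildcard values, not the real uninas mapping
standard_paths = {
    'path_tmp': '/tmp/uninas',
    'path_conf_net': '/path/to/uninas/configs/network',
}

def replace_standard_paths_sketch(path: str) -> str:
    # replace every known '{wildcard}' with its configured directory
    for key, value in standard_paths.items():
        path = path.replace('{%s}' % key, value)
    return path

assert replace_standard_paths_sketch('{path_tmp}/viz/') == '/tmp/uninas/viz/'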
Example 2
def assert_stats_match(name,
                       task_cfg,
                       cfg: dict,
                       num_params=None,
                       num_macs=None):
    cfg_dir = replace_standard_paths('{path_tmp}/tests/cfgs/')
    cfg_path = Builder.save_config(cfg, cfg_dir, name)
    exp = Main.new_task(
        task_cfg,
        args_changes={
            '{cls_data}.fake': True,
            '{cls_data}.batch_size_train': 2,
            '{cls_data}.batch_size_test': -1,
            '{cls_task}.is_test_run': True,
            '{cls_task}.save_dir': '{path_tmp}/tests/workdir/',
            "{cls_network}.config_path": cfg_path,
            "{cls_trainer}.ema_decay": -1,
            'cls_network_heads':
            'ClassificationHead',  # necessary for the DARTS search space to disable the aux heads
        },
        raise_unparsed=False)
    net = exp.get_method().get_network()
    macs = exp.get_method().profile_macs()
    net.eval()
    # print(net)
    cp = count_parameters(net)
    if num_params is not None:
        assert cp == num_params, 'Got unexpected num params for %s: %d, expected %d, diff: %d'\
                                 % (name, cp, num_params, abs(cp - num_params))
    if num_macs is not None:
        assert macs == num_macs, 'Got unexpected num macs for %s: %d, expected %d, diff: %d'\
                                 % (name, macs, num_macs, abs(macs - num_macs))
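
The count_parameters helper used above is not shown here; a minimal sketch of what such a helper usually does in PyTorch (the name and exact behavior are assumptions):

import torch.nn as nn

def count_parameters_sketch(net: nn.Module) -> int:
    # sum the element counts of all trainable parameters
    return sum(p.numel() for p in net.parameters() if p.requires_grad)

print(count_parameters_sketch(nn.Linear(10, 2)))  # 10*2 weights + 2 biases = 22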
Example 3
    def __init__(self, args: Namespace, *args_, **kwargs):
        super().__init__(args, *args_, **kwargs)

        # args
        self.s1_path = replace_standard_paths(
            self._parsed_argument('s1_path', args))
        self.reset_bn = self._parsed_argument('reset_bn', args)

        # files
        self.tmp_load_path = '%s/checkpoints/checkpoint.tmp.pt' % self.save_dir
        os.makedirs(os.path.dirname(self.tmp_load_path), exist_ok=True)
        shutil.copyfile('%s/data.meta.pt' % self.s1_path,
                        '%s/data.meta.pt' % self.save_dir)

        # one method, one trainer... could be executed in parallel in future?
        log_headline(self.logger, 'setting up...')
        self.add_method()
        self.add_trainer(method=self.get_method(),
                         save_dir=self.save_dir,
                         num_devices=-1)
        self.log_detailed()
        self.get_method().get_network().set_forward_strategy(False)

        # algorithms
        estimator_kwargs = dict(trainer=self.trainer[0],
                                load_path=self.tmp_load_path)
        self.algorithm, self.estimators, self.termination = PymooHPOUtils.prepare(
            self, self.logger, estimator_kwargs, args)
Example 4
def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 4,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_network}.config_path": cfg_path,
    })
    net = exp.get_method().get_network()
    vt = VizTree(net)
    vt.print()
    vt.plot(save_path + 'net', add_subgraphs=True)
    print('Saved cell viz to %s' % save_path)
Example 5
 def find_last_checkpoint_path(cls,
                               save_dir: str,
                               index=0,
                               try_general_checkpoint=True) -> str:
     """
     attempt finding the checkpoint path in a dir,
     if 'save_dir' is a file, return it
     if there is a general checkpoint and 'try_general_checkpoint', return its path
     otherwise try finding the most recent checkpoint of the CheckpointCallback with index 'index'
     """
     save_dir = replace_standard_paths(save_dir)
     # try as path and general path
     if os.path.isfile(save_dir):
         return save_dir
     if try_general_checkpoint and os.path.isfile(
             cls._general_checkpoint_file(save_dir)):
         return cls._general_checkpoint_file(save_dir)
     # try by index
     lst = sorted(cls.list_infos(save_dir, index),
                  key=lambda inf: inf.checkpoint_path)
     if len(lst) > 0:
         return lst[-1].checkpoint_path
     # try to find any checkpoint.pt in dir
     for path in glob.glob('%s/**/checkpoint.pt' % save_dir,
                           recursive=True):
         return path
     # failure
     LoggerManager().get_logger().info(
         'Can not find a uninas checkpoint (history) in: %s' % save_dir)
     return ''
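
The lookup order above (explicit file, general checkpoint, indexed callback info, any checkpoint.pt) can be condensed into a small stand-alone sketch; the checkpoint.pt file name comes from the snippet, everything else is assumed:

import glob
import os

def last_checkpoint_sketch(save_dir: str) -> str:
    # an explicit file wins, otherwise take the lexicographically last checkpoint.pt
    if os.path.isfile(save_dir):
        return save_dir
    candidates = sorted(glob.glob('%s/**/checkpoint.pt' % save_dir, recursive=True))
    return candidates[-1] if candidates else ''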
Example 6
 def find_pretrained_weights_path(cls,
                                  path: str,
                                  name: str = 'pretrained',
                                  raise_missing=True) -> str:
     """
     attempt finding pretrained weights in a dir,
     no matter if checkpoint or external
     """
     maybe_path = maybe_download(path, FileType.WEIGHTS)
     if isinstance(maybe_path, str):
         return maybe_path
     path = replace_standard_paths(path)
     if len(path) == 0 or os.path.isfile(path):
         return path
     # try looking for a checkpoint
     p = cls.find_last_checkpoint_path(path)
     if os.path.isfile(p):
         return p
     # glob any .pt/.pth file with the network name in it
     glob_path = '%s/**/*%s*.pt*' % (path, name)
     paths = glob.glob(glob_path, recursive=True)
     if len(paths) > 0:
         return paths[0]
     # failure
     if raise_missing:
         raise FileNotFoundError(
             'can not find any pretrained weights in "%s" or "%s"' %
             (path, glob_path))
     return ''
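
The glob pattern built above matches both .pt and .pth files anywhere below the dir; a quick illustration of the string formatting with made-up inputs:

# pure string formatting, mirroring the pattern above with placeholder values
path, name = '/tmp/weights', 'pretrained'
glob_path = '%s/**/*%s*.pt*' % (path, name)
assert glob_path == '/tmp/weights/**/*pretrained*.pt*'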
Example 7
 def run(self):
     try:
         path = replace_standard_paths('{path_tmp}/tmp.run_config')
         self.interactive.to_json(path)
         Main.new_task(path).run()
     except Exception as e:
         LoggerManager().get_logger().error(str(e), exc_info=e)
         tkm.showwarning(message=str(e))
Example 8
 def make_nats(path_full: str,
               path_save: str = None) -> MiniNATSBenchTabularBenchmark:
     api = create(replace_standard_paths(path_full),
                  'tss',
                  fast_mode=True,
                  verbose=True)
     mini = MiniNATSBenchTabularBenchmark.make_from_full_api(api)
     if isinstance(path_save, str):
         mini.save(path_save)
     return mini
Example 9
def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 2,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_task}.note": "viz",
        "{cls_network}.config_path": cfg_path,
    })
    net = exp.get_method().get_network()
    for s in ['n', 'r']:
        for cell in net.get_cells():
            if cell.name.startswith(s):
                visualize_cell(cell, save_path, s)
                break
    print('Saved cell viz to %s' % save_path)
Example 10
 def load_config(cls, config_file_path: str) -> dict:
     """ load a json config file from given path """
     config_file_path = replace_standard_paths(config_file_path)
     if os.path.isfile(config_file_path):
         with open(config_file_path, 'r') as file:
             return json.load(file)
     else:
         raise FileNotFoundError(
             'Could not find file "%s" / "%s"' %
             (config_file_path, os.path.abspath(config_file_path)))
Example 11
def arg_list_from_json(paths: str) -> [str]:
    args = []
    for path in split(paths):
        path = replace_standard_paths(path)
        print('using config file: %s' % path)
        with open(path) as config_file:
            config = json.load(config_file)
            for k, v in config.items():
                args.append('--%s=%s' % (k, v))
    return args
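
For a small config, arg_list_from_json produces one --key=value string per entry; a sketch with made-up keys (these are assumptions, not a real run_config):

config = {'cls_task': 'SingleSearchTask', '{cls_task}.is_test_run': True}
args = ['--%s=%s' % (k, v) for k, v in config.items()]
assert args == ['--cls_task=SingleSearchTask', '--{cls_task}.is_test_run=True']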
Example 12
def visualize_genotype(wrapper: NetWrapper, save_dir: str):
    config_name = get_var_name(wrapper)
    save_dir = replace_standard_paths('%s%s/' % (save_dir, config_name))
    wrapper_net, config, _ = wrapper.generate(save_dir, 'viz')
    assert isinstance(wrapper_net, SearchUninasNetwork)

    g = Digraph(format='pdf', engine='dot',
                edge_attr=dict(fontsize='20', fontname="times"),
                node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5',
                               penwidth='2', fontname="times"))
    cell_order = config.get('kwargs').get('cell_order')
    stem_name = wrapper_net.get_network().get_stem().__class__.__name__

    g.node('stem', label=short_name.get(stem_name, stem_name), width=width_str(expansion='3'),
           fillcolor=colors.get('misc'))
    node_names = ['stem']

    for i, cell in enumerate(wrapper_net.get_network().get_cells()):
        assert isinstance(cell, SingleLayerCell)
        name = cell_order[i]
        op_cfg = config.get('kwargs').get('cell_configs').get(name).get('submodules').get('op')
        cell_cls = op_cfg.get('name')
        op_kwargs = op_cfg.get('kwargs')

        e = op_kwargs.get('expansion')
        k = op_kwargs.get('k_size')
        s_in = cell.cached.get('shape_in')[0]
        s_out = cell.cached.get('shape_out')[0]

        label = '{name} E{e} K{k}'.format(**{
            'name': short_name.get(cell_cls, cell_cls),
            'e': e,
            'k': k,
        })
        g.node(name, label=label, width=width_str(expansion=e), fillcolor=colors.get(k))
        node_names.append(name)
        if len(node_names) > 1:
            g.edge(node_names[-2], node_names[-1], label='\t'+'*'.join([str(s) for s in s_in.shape]))
        print('{:<10}{:<30}{:<30}{:<30}{}'.format(cell.name, cell_cls, s_in.str(), s_out.str(), str(op_kwargs)))

    head = wrapper_net.get_network().get_heads()[-1]
    assert isinstance(head, FeatureMixClassificationHead)

    g.node('fmix', label='Conv K1', width=width_str(expansion='3'), fillcolor=colors.get('misc'))
    node_names.append('fmix')
    s_in = head.cached.get('shape_in')
    g.edge(node_names[-2], node_names[-1], label='\t'+'*'.join([str(s) for s in s_in.shape]))

    g.node('head', label='classification', width=width_str(expansion='3'), fillcolor=colors.get('misc'))
    node_names.append('head')
    s_in = head.cached.get('shape_inner')
    g.edge(node_names[-2], node_names[-1], label='\t'+'*'.join([str(s) for s in s_in.shape]))

    g.view(filename='%snet' % save_dir)
    print('Saved cell viz to %s' % save_dir)
Example 13
 def save_config(cls, config: dict, config_dir: str,
                 config_name: str) -> str:
     if config is None:
         return ''
     config_dir = replace_standard_paths(config_dir)
     os.makedirs(config_dir, exist_ok=True)
     path = '%s/%s%s' % (config_dir, config_name, cls.extension_net_config)
     path = os.path.abspath(path)
     with open(path, 'w+') as outfile:
         json.dump(config, outfile, ensure_ascii=False, indent=2)
     LoggerManager().get_logger().info('Wrote net config to %s' % path)
     return path
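
save_config is the counterpart to load_config in Example 10; a minimal round-trip sketch using plain json (the .network_config extension is an assumption):

import json
import os
import tempfile

cfg = {'name': 'ExampleNetworkBody', 'kwargs': {}}
cfg_dir = tempfile.mkdtemp()
cfg_path = os.path.join(cfg_dir, 'example.network_config')  # extension assumed
with open(cfg_path, 'w+') as outfile:
    json.dump(cfg, outfile, ensure_ascii=False, indent=2)
with open(cfg_path, 'r') as infile:
    assert json.load(infile) == cfg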
Example 14
 def find_classes_in_config(self, config_file_path: str) -> {str: list}:
     """ load a json config file from given path and figure out the used classes in it """
     config_file_path = replace_standard_paths(config_file_path)
     cfg = self.load_config(config_file_path)
     names = self._rec_list_attr(cfg, 'name')
     return {
         'cls_network_body':
         [n for n in names if n.endswith('NetworkBody')],
         'cls_network_cells': [n for n in names if n.endswith('Cell')],
         'cls_network_stem': [n for n in names if n.endswith('Stem')],
         'cls_network_heads': [n for n in names if n.endswith('Head')],
     }
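
The grouping above is plain suffix matching on class names; a stand-alone sketch, with names partly borrowed from the other snippets and partly made up:

names = ['StackedCellsNetworkBody', 'SingleLayerCell', 'ConvStem', 'ClassificationHead']
grouped = {
    'cls_network_body': [n for n in names if n.endswith('NetworkBody')],
    'cls_network_cells': [n for n in names if n.endswith('Cell')],
    'cls_network_stem': [n for n in names if n.endswith('Stem')],
    'cls_network_heads': [n for n in names if n.endswith('Head')],
}
assert grouped['cls_network_cells'] == ['SingleLayerCell']
assert grouped['cls_network_heads'] == ['ClassificationHead']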
Example 15
 def from_file(cls, file_path: str):
     """ create and load a profiler from a profiler save file """
     file_path = replace_standard_paths(file_path)
     assert os.path.isfile(
         file_path), "File does not exist: %s" % str(file_path)
     cls_name = torch.load(file_path).get('meta').get('cls')
     profiler = Register.profilers.get(cls_name)()
     profiler.load(file_path)
     if profiler.get('is_test_run'):
         LoggerManager().get_logger().warning(
             "Loading profiler data from a file created in a test run!")
     return profiler
Example 16
    def from_args(cls, args: Namespace, index: int = None) -> 'AbstractDataSet':
        # parsed arguments, and the global save dir
        all_args = cls._all_parsed_arguments(args, index=index)

        data_dir = replace_standard_paths(all_args.pop('dir'))
        fake = all_args.pop('fake')
        download = all_args.pop('download') and not fake

        try:
            _, save_dir = find_in_args(args, '.save_dir')
            save_dir = replace_standard_paths(save_dir)
        except ValueError:
            save_dir = ""

        # augmentations per data point and batch, for training and test
        tr_d, tr_b, te_d, te_b = [], [], [], []
        for i, aug_set in enumerate(cls._parsed_meta_arguments(Register.augmentation_sets, 'cls_augmentations', args, index=index)):
            tr_d_, tr_b_ = aug_set.get_train_transforms(args, i, cls)
            te_d_, te_b_ = aug_set.get_test_transforms(args, i, cls)
            tr_d.extend(tr_d_)
            tr_b.extend(tr_b_)
            te_d.extend(te_d_)
            te_b.extend(te_b_)
        if cls.is_on_images():
            final_transforms = [transforms.ToTensor(), transforms.Normalize(cls.data_mean, cls.data_std)]
        else:
            final_transforms = []
        train_transforms = transforms.Compose(tr_d + final_transforms)
        test_transforms = transforms.Compose(te_d + final_transforms)
        train_batch_aug = BatchAugmentations(tr_b) if len(tr_b) > 0 else None
        test_batch_aug = BatchAugmentations(te_b) if len(te_b) > 0 else None

        return cls(data_dir=data_dir, save_dir=save_dir,
                   bs_train=all_args.pop('batch_size_train'), bs_test=all_args.pop('batch_size_test'),
                   train_transforms=train_transforms, test_transforms=test_transforms,
                   train_batch_aug=train_batch_aug, test_batch_aug=test_batch_aug,
                   num_workers=all_args.pop('num_workers'), num_prefetch=all_args.pop('num_prefetch'),
                   valid_split=all_args.pop('valid_split'), valid_shuffle=all_args.pop('valid_shuffle'),
                   fake=fake, download=download, **all_args)
Example 17
def get_imagenet(data_dir: str, num_workers=8, batch_size=8, aug_dict: dict = None) -> AbstractDataSet:
    data_kwargs = {
        "cls_data": "Imagenet1000Data",
        "Imagenet1000Data.fake": False,
        "Imagenet1000Data.dir": replace_standard_paths(data_dir),
        "Imagenet1000Data.num_workers": num_workers,
        "Imagenet1000Data.batch_size_train": batch_size,
        "Imagenet1000Data.batch_size_test": batch_size,

    }
    if aug_dict is None:
        aug_dict = {
            "cls_augmentations": "DartsImagenetAug",
            "DartsImagenetAug#0.crop_size": 224,
        }
    data_kwargs.update(aug_dict)
    return get_dataset(data_kwargs)
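
A hypothetical call for illustration; the directory below is a placeholder, and the call requires the ImageNet data (or the library's fake-data mode) to be available:

# hypothetical usage; '{path_data}/ILSVRC2012' is a placeholder directory
data_set = get_imagenet('{path_data}/ILSVRC2012', num_workers=4, batch_size=32,
                        aug_dict={"cls_augmentations": "DartsImagenetAug",
                                  "DartsImagenetAug#0.crop_size": 224})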
Example 18
 def load(cls, file_path: str, pl_module: AbstractMethod = None) -> dict:
     """ load method checkpoint from method checkpoint file and return it """
     file_path = replace_standard_paths(file_path)
     if os.path.isfile(file_path):
         LoggerManager().get_logger().info('Found checkpoint: %s' %
                                           file_path)
         checkpoint = torch.load(file_path)
         if pl_module is not None:
             pl_module.load_state_dict(checkpoint['state_dict'])
             pl_module.on_load_checkpoint(checkpoint)
             LoggerManager().get_logger().info(
                 'Loaded weights from file: %s' % file_path)
         return checkpoint
     else:
         LoggerManager().get_logger().info(
             'Can not load weights, does not exist / not a file: %s' %
             file_path)
         return {}
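
The torch.load / load_state_dict round trip used above, reduced to a self-contained sketch with a throwaway model and path:

import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
ckpt_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pt')
torch.save({'state_dict': model.state_dict()}, ckpt_path)
model.load_state_dict(torch.load(ckpt_path)['state_dict'])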
Example 19
    def test_examples(self):
        """
        run all examples in /experiments/examples/ in ascending name order
        """
        paths = sorted(
            glob.glob("%s/experiments/examples/*.py" %
                      replace_standard_paths("{path_project_dir}")))
        start_idx = 0  # just to get quickly to the failing one
        for i, path in enumerate(paths):
            if i < start_idx:
                continue
            if path.endswith('__init__.py'):
                continue
            if 'pbt' in path:
                continue

            if os.system("python3 %s" % path) > 0:
                assert False, "Failed running i=%d path=%s, got an error" % (
                    i, path)
Example 20
 def _save(self, save_path: str):
     if isinstance(save_path, str):
         save_path = replace_standard_paths(save_path)
         name = os.path.basename(save_path)
         path = os.path.dirname(save_path)
         os.makedirs(path, exist_ok=True)
         s = '%s/%s' % (path, name)
         data = dict(
             cls=self.__class__.__name__,
             default_data_set=self.default_data_set,
             default_result_type=self.default_result_type,
             value_space=self.get_value_space(),
             bench_name=self.bench_name,
             bench_description=self.bench_description,
             results={k: r.state_dict() for k, r in self.results.items()},
             arch_to_idx=self.arch_to_idx,
             tuple_to_str=self.tuple_to_str,
             tuple_to_idx=self.tuple_to_idx)
         torch.save(data, s)
Example 21
def save_as_json(args: args_type, file_path: str, wildcards: dict):
    """ save the given Namespace as ordered json file, replacing names with wildcards """
    file_path = replace_standard_paths(file_path)
    name_to_wildcard = {'%s.' % v: '{%s}.' % k for k, v in wildcards.items()}

    # generate a run_config from current args, replace with wildcards, sort in order of cls_* meta args
    config, config_sorted = {}, {}
    for k, v in items(args).items():
        for k2, v2 in name_to_wildcard.items():
            k = k.replace(k2, v2)
        config[k] = v
    for name in config.keys():
        if name.startswith('cls_'):
            for k, v in config.items():
                if k.startswith(name) or k.startswith('{%s' % name):
                    config_sorted[k] = v

    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'w+') as outfile:
        json.dump(config_sorted, outfile, indent=4)
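
The key rewriting in save_as_json maps concrete class names back to their meta-argument wildcards; a sketch with an assumed wildcard mapping:

wildcards = {'cls_task': 'SingleSearchTask'}  # assumed mapping for illustration
name_to_wildcard = {'%s.' % v: '{%s}.' % k for k, v in wildcards.items()}

key = 'SingleSearchTask.save_dir'
for k2, v2 in name_to_wildcard.items():
    key = key.replace(k2, v2)
assert key == '{cls_task}.save_dir'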
Example 22
def main():
    parser = argparse.ArgumentParser(
        description='uninas: generate a network config from a simple genotype description')
    parser.add_argument('--cells',
                        type=str,
                        default=None,
                        help='which config to generate, all available if None')
    args = parser.parse_args()
    args.save_dir = replace_standard_paths('{path_conf_net_originals}/')

    if args.cells is not None:
        all_cell_names = [args.cells]
    else:
        all_cell_names = []
        for key, value in list(globals().items()):
            if isinstance(value, Genotype):
                all_cell_names.append(key)

    for cell_name in all_cell_names:
        print('Name:\t\t%s' % cell_name)
        generate_from_name(cell_name)
Example 23
    def __init__(self, args: Namespace, *args_, **kwargs):
        AbstractNetTask.__init__(self, args, *args_, **kwargs)

        # args
        self.reset_bn = self._parsed_argument('reset_bn', args)
        self.s1_path = replace_standard_paths(
            self._parsed_argument('s1_path', args))

        # files
        self.tmp_load_path = '%s/checkpoint.tmp.pt' % self.save_dir
        os.makedirs(os.path.dirname(self.tmp_load_path), exist_ok=True)
        shutil.copyfile('%s/data.meta.pt' % self.s1_path,
                        '%s/data.meta.pt' % self.save_dir)

        # one method, one trainer... could be executed in parallel in future?
        log_headline(self.logger, 'setting up...')
        self.add_method()
        self.add_trainer(method=self.get_method(),
                         save_dir=self.save_dir,
                         num_devices=-1)
        self.log_detailed()
        self.get_method().get_network().set_forward_strategy(False)

        # algorithms
        estimator_kwargs = dict(trainer=self.trainer[0],
                                load_path=self.tmp_load_path)
        self.hpo, self.constraints, self.objectives = SelfHPOUtils.prepare(
            self, self.logger, estimator_kwargs, args)

        # arc space
        space = ValueSpace(*[
            DiscreteValues.interval(0, n)
            for n in self.get_method().strategy_manager.get_num_choices(
                unique=True)
        ])
        self._architecture_space = SelfHPOUtils.mask_architecture_space(
            self.args, space)
Example 24
    def __init__(self, data_dir: str, save_dir: Union[str, None],
                 bs_train: int, bs_test: int,
                 train_transforms: transforms.Compose, test_transforms: transforms.Compose,
                 train_batch_aug: Union[BatchAugmentations, None], test_batch_aug: Union[BatchAugmentations, None],
                 num_workers: int, num_prefetch: int,
                 valid_split: Union[int, float], valid_shuffle: bool,
                 fake: bool, download: bool,
                 **additional_args):
        """

        :param data_dir: where to find (or download) the data set
        :param save_dir: global save dir, can store and reuse the info which data was used in the random valid split
        :param bs_train: batch size for the train loader
        :param bs_test: batch size for the test loader, <= 0 to have the same as bs_train
        :param train_transforms: train augmentations (on each data point individually)
        :param test_transforms: test augmentations (on each data point individually)
        :param train_batch_aug: train augmentations (across the entire batch)
        :param test_batch_aug: test augmentations (across the entire batch)
        :param num_workers: number of workers prefetching data
        :param num_prefetch: number of batches prefetched by every worker
        :param valid_split: absolute number of data points if int or >1, otherwise a fraction of the training set
        :param valid_shuffle: whether to shuffle validation data
        :param fake: use fake data instead (no need to provide either real data or enabling downloading)
        :param download: whether downloading is allowed
        :param additional_args: arguments that are added and used by child classes
        """
        super().__init__()
        logger = LoggerManager().get_logger()
        self.dir = data_dir
        self.bs_train = bs_train
        self.bs_test = bs_test if bs_test > 0 else self.bs_train
        self.num_workers, self.num_prefetch = num_workers, num_prefetch
        self.valid_shuffle = valid_shuffle
        self.additional_args = additional_args

        self.fake = fake
        self.download = download and not self.fake
        if self.download and (not self.can_download):
            LoggerManager().get_logger().warning("Downloading is enabled, but this dataset can not be downloaded.")

        self.train_transforms = train_transforms
        self.test_transforms = test_transforms
        self.train_batch_augmentations = train_batch_aug
        self.test_batch_augmentations = test_batch_aug

        # load/create meta info dict
        if isinstance(save_dir, str) and len(save_dir) > 0:
            meta_path = '%s/data.meta.pt' % replace_standard_paths(save_dir)
            if os.path.isfile(meta_path):
                meta = torch.load(meta_path)
            else:
                meta = defaultdict(dict)
        else:
            meta, meta_path = defaultdict(dict), None

        # give subclasses a good spot to react to additional arguments
        self._before_loading()

        # data
        if self.fake:
            train_data = self._get_fake_train_data(self.train_transforms)
            self.test_data = self._get_fake_test_data(self.test_transforms)
        else:
            train_data = self._get_train_data(self.train_transforms)
            self.test_data = self._get_test_data(self.test_transforms)

        # split train into train+valid or using stand-alone valid set
        if valid_split > 0:
            s1 = int(valid_split) if valid_split >= 1 else int(len(train_data)*valid_split)
            if s1 >= len(train_data):
                logger.warning("Tried to set valid split larger than the training set size, setting to 0.5")
                s1 = len(train_data)//2
            s0 = len(train_data) - s1
            if meta['splits'].get((s0, s1), None) is None:
                meta['splits'][(s0, s1)] = torch.randperm(s0+s1).tolist()
            indices = meta['splits'][(s0, s1)]
            # note: np.int was removed in NumPy >= 1.24, plain int is the drop-in replacement
            self.valid_data = torch.utils.data.Subset(train_data, np.array(indices[s0:]).astype(int))
            train_data = torch.utils.data.Subset(train_data, np.array(indices[0:s0]).astype(int))
            logger.info('Data Set: splitting training set, will use %s data points as validation set' % s1)
            if self.length[1] > 0:
                logger.info('Data Set: a dedicated validation set exists, but it will be replaced.')
        elif self.length[1] > 0:
            if self.fake:
                self.valid_data = self._get_fake_valid_data(self.test_transforms)
            else:
                self.valid_data = self._get_valid_data(self.test_transforms)
            logger.info('Data Set: using the dedicated validation set with test augmentations')
        else:
            self.valid_data = None
            logger.info('Data Set: not using a validation set at all.')
        self.train_data = train_data

        # shapes
        data, label = self.train_data[0]
        self.data_shape = Shape(list(data.shape))

        # save meta info dict
        if meta_path is not None:
            torch.save(meta, meta_path)
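
The valid_split handling above treats values >= 1 as an absolute number of data points and fractions as a share of the training set; a worked example with illustrative numbers:

train_len, valid_split = 50000, 0.1
s1 = int(valid_split) if valid_split >= 1 else int(train_len * valid_split)
s0 = train_len - s1
assert (s0, s1) == (45000, 5000)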
Example 25
                metrics[k] = metrics.get(k, {data_set_name: -1})
            # result
            r = MiniResult(
                arch_index=i,
                arch_str="%s(%s)" % (space_name, ", ".join([str(v) for v in candidate.values])),
                arch_tuple=candidate.values,
                **metrics
            )

            assert tuple_to_str.get(r.arch_tuple) is None, "can not yet merge duplicate architecture results"
            results[i] = r
            arch_to_idx[r.arch_str] = i
            tuple_to_idx[r.arch_tuple] = i
            tuple_to_str[r.arch_tuple] = r.arch_str

        data_sets = list(results.get(0).params.keys())
        return MiniNASSearchTabularBenchmark(
            default_data_set=data_sets[0],
            default_result_type=default_result_type,
            bench_name="%s on %s" % (space_name, data_sets[0]),
            bench_description="super-network evaluation results",
            value_space=space, results=results, arch_to_idx=arch_to_idx,
            tuple_to_str=tuple_to_str, tuple_to_idx=tuple_to_idx)


if __name__ == '__main__':
    Builder()
    path_ = replace_standard_paths('{path_tmp}/s2_bench/bench-SearchUninasNetwork_on_Imagenet1000Data.pt')
    mini_ = MiniNASSearchTabularBenchmark.load(path_)
    explore(mini_)
Example 26
 def visualize(self):
     visualize_args_tree(self.interactive.root).view(
         filename="args_tree",
         directory=replace_standard_paths("{path_tmp}"),
         quiet_view=True,
         cleanup=True)
Example 27
from uninas.utils.paths import replace_standard_paths
from uninas.register import Register


def example_export_network(path: str) -> AbstractUninasNetwork:
    """ create a new network and export it, does not require to have onnx installed """
    network = get_network("FairNasC",
                          Shape([3, 224, 224]),
                          Shape([1000]),
                          weights_path=None)
    network = network.cuda()
    network.export_onnx(path, export_params=True)
    return network


try:
    import onnx

    if __name__ == '__main__':
        logger = LoggerManager().get_logger()
        export_path = replace_standard_paths("{path_tmp}/onnx/FairNasC.onnx")
        net1 = example_export_network(export_path)

        log_headline(logger, "onnx graph")
        net2 = onnx.load(export_path)
        onnx.checker.check_model(net2)
        logger.info(onnx.helper.printable_graph(net2.graph))

except ImportError as e:
    Register.missing_import(e)
Example 28
def visualize_args_tree(node: ArgsTreeNode):
    g = Digraph(format='pdf',
                engine='dot',
                edge_attr=dict(fontsize='20', fontname="times"),
                node_attr=dict(style='filled',
                               shape='rect',
                               align='center',
                               fontsize='20',
                               height='0.5',
                               penwidth='2',
                               fontname="times"))
    _visualize_args_tree(node, g)
    return g


if __name__ == '__main__':
    from uninas.builder import Builder
    Builder()

    args_list = arg_list_from_json("/tmp/uninas/s1/task.run_config")

    root = ArgsTreeNode(Main)
    root.build_from_args(args_list)
    print("-" * 200)
    visualize_args_tree(root).view(
        filename="args_tree",
        directory=replace_standard_paths("{path_tmp}"),
        cleanup=True,
        quiet_view=True)
Example 29
 def __init__(self, save_dir: str, index: int, **_):
     super().__init__()
     self._save_dir = replace_standard_paths(save_dir)
     assert isinstance(index, int)
     self._index = index
Example 30
    value_space = ValueSpace(
        *[DiscreteValues.interval(0, 5) for _ in range(6)])
    return MiniNASParsedTabularBenchmark.make_from_dirs(
        path, space_name, value_space)


def sample_architectures(mini: MiniNASParsedTabularBenchmark, n=10):
    """ sample some random architectures to train """
    for i in range(n):
        print(i, '\t\t', mini.get_value_space().random_sample())


if __name__ == '__main__':
    Builder()

    path_ = replace_standard_paths(
        '{path_data}/generated_bench/bench201_n5c16_1.pt')
    # mini_ = create_bench201("/mnt/tcml-master01/mnt/beegfs/home/laube/Data/experiments/git/uninas/full/test_s3_bench/Cifar100Data/bench201_n5c16/")

    # path_ = replace_standard_paths('{path_data}/generated_bench/bench201_n4c64_1.pt')
    # mini_ = create_bench201("/mnt/tcml-master01/mnt/beegfs/home/laube/Data/experiments/git/uninas/full/test_s3_bench/Cifar100Data/bench201_n4c64/")

    # path_ = replace_standard_paths('{path_data}/generated_bench/SIN_fairnas_mini_only2.pt')
    # mini_ = create_fairnas("/mnt/tcml-master01/mnt/beegfs/home/laube/Data/experiments/git/uninas/full/s3sin/SubImagenet100Data/fairnas/2/")

    # path_ = replace_standard_paths('{path_data}/generated_bench/SIN_fairnas_mini_all.pt')
    # mini_ = create_fairnas("/mnt/tcml-master01/mnt/beegfs/home/laube/Data/experiments/git/uninas/full/s3sin/SubImagenet100Data/fairnas/")

    mini_ = MiniNASParsedTabularBenchmark.load(path_)

    sample_architectures(mini_)