Example #1
    def run(self):
        common_s2_prepare_run(self.logger, self.trainer, self.s1_path,
                              self.tmp_load_path, self.reset_bn, self.methods)
        checkpoint_dir = self.checkpoint_dir(self.save_dir)
        candidate_dir = '%s/candidates/' % checkpoint_dir

        # get problem, run
        xu = np.array([
            n - 1 for n in self.get_method().strategy_manager.get_num_choices(
                unique=True)
        ])
        xl = np.zeros_like(xu)
        problem = PymooProblem(estimators=self.estimators,
                               xl=xl,
                               xu=xu,
                               n_var=len(xu))
        wrapper = PymooHPOUtils.run(problem, self.algorithm, self.termination,
                                    self.seed, self.logger, checkpoint_dir)

        # save results
        for sr in wrapper.sorted_best():
            self.get_method().get_network().forward_strategy(
                fixed_arc=tuple(sr.x))
            Builder.save_config(
                self.get_method().get_network().config(finalize=True),
                candidate_dir,
                'candidate-%s' % '-'.join([str(xs) for xs in sr.x]))
Example #2
    def _run(self, save=True):
        common_s2_prepare_run(self.logger, self.trainer, self.s1_path,
                              self.tmp_load_path, self.reset_bn, self.methods)
        checkpoint_dir = self.checkpoint_dir(self.save_dir)
        candidate_dir = '%s/candidates/' % checkpoint_dir
        file_viz = '%sx.pdf' % checkpoint_dir
        self.get_method().eval()

        # run
        algorithm = self.hpo.run_opt(hparams=self.args,
                                     logger=self.logger,
                                     checkpoint_dir=checkpoint_dir,
                                     value_space=self._architecture_space,
                                     constraints=self.constraints,
                                     objectives=self.objectives)
        population = algorithm.get_total_population(sort=True)

        # save results
        if save:
            population.plot(self.objectives[0].key,
                            self.objectives[1].key,
                            show=False,
                            save_path=file_viz)
            for candidate in population.fronts[0]:
                self.get_method().get_network().forward_strategy(
                    fixed_arc=candidate.values)
                Builder.save_config(
                    self.get_method().get_network().config(finalize=True),
                    candidate_dir, 'candidate-%s' %
                    '-'.join([str(g) for g in candidate.values]))
        return algorithm, population
Example #3
    def extend_args(cls, args_list: [str]):
        """
        allow modifying the arguments list before other classes' arguments are dynamically added
        this should be used sparsely, as it is hard to keep track of
        """
        # find last cls_network_body
        super().extend_args(args_list)

        # first find the correct config path in the args list, enabling short names (not only full paths)
        config_path = find_in_args_list(
            args_list,
            ['{cls_network}.config_path',
             '%s.config_path' % cls.__name__])
        config_path = Builder.find_net_config_path(config_path)

        # extract used classes from the network config file, add them to the current task config if missing
        used_classes = Builder().find_classes_in_config(config_path)
        network_name = used_classes['cls_network_body'][0]
        cls_network_body = Register.network_bodies.get(network_name)
        optional_meta = [
            m.argument.name for m in cls_network_body.meta_args_to_add()
            if m.optional_for_loading
        ]
        print(
            '\tbuilding a new net (config_only=False), added missing args from the network config file'
        )
        for cls_n in ['cls_network_body'] + optional_meta:
            cls_c = find_in_args_list(args_list, [cls_n])
            if cls_c is None or len(cls_c) == 0:
                cls_v = ', '.join(used_classes[cls_n])
                print('\t  %s -> %s' % (cls_n, cls_v))
                args_list.append('--%s=%s' % (cls_n, cls_v))
Example #4
 def from_args(cls, args: Namespace, index=None) -> 'RetrainUninasNetwork':
     """
     :param args: global argparse namespace
     :param index: index for the args
     """
     all_parsed = cls._all_parsed_arguments(args, index=index)
     config_path = all_parsed.pop('config_path')
     config_path = Builder.find_net_config_path(config_path)
     net = Register.builder.load_from_config(config_path)
     return cls(model_name=Builder.net_config_name(config_path),
                net=net,
                **all_parsed)
Example #5
def assert_stats_match(name,
                       task_cfg,
                       cfg: dict,
                       num_params=None,
                       num_macs=None):
    cfg_dir = replace_standard_paths('{path_tmp}/tests/cfgs/')
    cfg_path = Builder.save_config(cfg, cfg_dir, name)
    exp = Main.new_task(
        task_cfg,
        args_changes={
            '{cls_data}.fake': True,
            '{cls_data}.batch_size_train': 2,
            '{cls_data}.batch_size_test': -1,
            '{cls_task}.is_test_run': True,
            '{cls_task}.save_dir': '{path_tmp}/tests/workdir/',
            "{cls_network}.config_path": cfg_path,
            "{cls_trainer}.ema_decay": -1,
            # necessary for the DARTS search space to disable the aux heads
            'cls_network_heads': 'ClassificationHead',
        },
        raise_unparsed=False)
    net = exp.get_method().get_network()
    macs = exp.get_method().profile_macs()
    net.eval()
    # print(net)
    cp = count_parameters(net)
    if num_params is not None:
        assert cp == num_params, 'Got unexpected num params for %s: %d, expected %d, diff: %d'\
                                 % (name, cp, num_params, abs(cp - num_params))
    if num_macs is not None:
        assert macs == num_macs, 'Got unexpected num macs for %s: %d, expected %d, diff: %d'\
                                 % (name, macs, num_macs, abs(macs - num_macs))
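
# Usage sketch (not part of the original source): wiring assert_stats_match with a config
# loaded via Builder; the config name, the task config path, and passing None to skip both
# checks are illustrative assumptions.
cfg = Builder.load_config(Builder.find_net_config_path('MobileNetV2'))
assert_stats_match('MobileNetV2_rebuilt',
                   '{path_conf_tasks}/d1_dartsv1.run_config',  # placeholder task config
                   cfg,
                   num_params=None,  # None skips the parameter-count check
                   num_macs=None)    # None skips the MACs check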
Example #6
    def new_task(cls,
                 cla: cla_type = None,
                 args_changes: dict = None,
                 raise_unparsed=True) -> AbstractTask:
        """
        :param cla:
            str: path to run_config file(s), (separated by commas), overrules other command line args if this exists
            list: specified command line arguments, overrules default system arguments
            None: use the system arguments
        :param args_changes: optional dictionary of changes to the command line arguments
        :param raise_unparsed: raise an exception if there are unparsed arguments left
        :return: new task as defined by the command line arguments
        """
        print("Creating new task")

        # reset any plotting
        plt.clf()
        plt.cla()

        # make sure everything is registered
        Builder()

        # from config file?
        if isinstance(cla, str):
            cla = arg_list_from_json(cla)
        print('-' * 50)

        # get arguments, insert args_changes
        args_list = sys.argv[1:] if cla is None else cla
        args_changes = args_changes if args_changes is not None else {}
        for k, v in args_changes.items():
            # append to args_list, which also covers the case where cla is None (system arguments)
            args_list.append('--%s=%s' % (k, v))

        parser = argparse.ArgumentParser(description='UniNAS Project')
        node = ArgsTreeNode(Main)
        node.build_from_args(args_list, parser)
        args, wildcards, failed_args, descriptions = node.parse(
            args_list, parser, raise_unparsed=raise_unparsed)

        # note failed wildcards
        if len(failed_args) > 0:
            print('-' * 50)
            print('Failed replacements for argparse:')
            print(', '.join(failed_args))

        # list all wildcards for convenience
        print('-' * 50)
        print('Wildcard replacements for argparse:')
        for k, v in wildcards.items():
            print('\t{:<25} ->  {}'.format('{%s}' % k, v))
        print('-' * 50)

        # clean up, create and return the task
        Argument.reset_cached()
        cls_task = cls._parsed_meta_argument(Register.tasks,
                                             'cls_task',
                                             args,
                                             index=None)
        print('Starting %s!' % cls_task.__name__)
        return cls_task(args, wildcards, descriptions=descriptions)
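
# Usage sketch (not part of the original source): two typical ways to call new_task per the
# docstring above; the run_config path, the argument keys, and the task class name are placeholders.
task = Main.new_task('{path_conf_tasks}/d1_dartsv1.run_config',
                     args_changes={
                         '{cls_data}.fake': True,
                         '{cls_task}.is_test_run': True,
                     })
# task.run()  # assumption: the returned task object exposes a run() entry point

# alternatively pass an explicit argument list instead of using sys.argv
task = Main.new_task(['--cls_task=SingleSearchTask'], raise_unparsed=False)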
Example #7
def get_network(config_path: str, input_shape: Shape, output_shape: Shape, weights_path: str = None) -> AbstractUninasNetwork:
    """
    create a network (model) from a config file, optionally load weights
    """
    builder = Builder()

    # get a new network
    network = builder.load_from_config(Builder.find_net_config_path(config_path))
    network = AbstractUninasNetwork(model_name="standalone", net=network, checkpoint_path="", assert_output_match=True)
    network.build(s_in=input_shape, s_out=output_shape)

    # load network weights; they are saved from a method, so the keys have to be mapped accordingly
    if isinstance(weights_path, str):
        CheckpointCallback.load_network(weights_path, network, num_replacements=1)

    return network
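
# Usage sketch (not part of the original source): the config name, the shapes, and the
# get_num_parameters() accessor (seen in Example #26) are assumptions for illustration.
net = get_network('MobileNetV2',
                  input_shape=Shape([3, 224, 224]),
                  output_shape=Shape([1000]),  # assumed label shape for 1000 classes
                  weights_path=None)           # or a path to a saved checkpoint
print(net.get_num_parameters())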
Example #8
 def test_rebuild(self):
     """
     getting finalized configs from which we can build modules
     """
     builder = Builder()
     StrategyManager().delete_strategy('default')
     StrategyManager().add_strategy(RandomChoiceStrategy(max_epochs=1))
     n, c, h, w = 2, 8, 16, 16
     x = torch.empty(size=[n, c, h, w])
     shape = Shape([c, h, w])
     layers = [
         FusedMobileInvertedConvLayer(name='mmicl',
                                      k_sizes=(3, 5, 7),
                                      expansions=(3, 6)),
         SuperConvThresholdLayer(k_sizes=(3, 5, 7)),
         SuperSepConvThresholdLayer(k_sizes=(3, 5, 7)),
         SuperMobileInvertedConvThresholdLayer(k_sizes=(3, 5, 7),
                                               expansions=(3, 6),
                                               sse_dict=dict(c_muls=(0.0,
                                                                     0.25,
                                                                     0.5))),
         LinearTransformerLayer(),
         SuperConvLayer(k_sizes=(3, 5, 7), name='scl1'),
         SuperSepConvLayer(k_sizes=(3, 5, 7), name='scl2'),
         SuperMobileInvertedConvLayer(k_sizes=(3, 5, 7),
                                      name='scl3',
                                      expansions=(2, 3, 4, 6)),
     ]
     for layer in layers:
         assert layer.build(shape, c) == shape
     StrategyManager().build()
     StrategyManager().forward()
     for layer in layers:
         print('\n' * 2)
         print(layer.__class__.__name__)
         for i in range(3):
             StrategyManager().randomize_weights()
             StrategyManager().forward()
             for finalize in [False, True]:
                 cfg = layer.config(finalize=finalize)
                 print('\t', i, 'finalize', finalize)
                 print('\t\tconfig dct:', cfg)
                 cfg_layer = builder.from_config(cfg)
                 assert cfg_layer.build(shape, c) == shape
                 cfg_layer.forward(x)
                 print('\t\tmodule str:', cfg_layer.str()[1:])
                 del cfg, cfg_layer
Example #9
    def list_all_arguments(cls):
        """ list all arguments of all classes that expose arguments """

        Builder()
        all_to_list = [cls] + [item.value for item in Register.all.values()]
        arg_str = '  {name:<30}{type:<20}{default:<35}{help:<80}{choices}'

        def maybe_print_args(name: str, arguments: [Argument]) -> bool:
            if len(arguments) > 0:
                print('\n%s' % name)
                for a in arguments:
                    choices = ''
                    if isinstance(a.choices, (list, tuple)):
                        choices = [
                            ('"%s"' % c) if isinstance(c, str) else str(c)
                            for c in a.choices
                        ]
                        choices = 'choices=[%s]' % ', '.join(choices)
                    elif isinstance(a.registered, (list, tuple)):
                        choices = [
                            ('"%s"' % c) if isinstance(c, str) else str(c)
                            for c in a.registered
                        ]
                        choices = 'meta=[%s]' % ', '.join(choices)
                    print(
                        arg_str.format(
                            **{
                                'name': a.name,
                                'type': str(a.type),
                                'default': str(a.default),
                                'help': a.help,
                                'choices': choices,
                            }))
                return True
            return False

        print(
            '\n', '-' * 140, '\n',
            'these meta arguments influence which classes+arguments will be dynamically added:',
            '\n', '(the classes may have further arguments, listed below)',
            '\n', '-' * 140)
        for v in all_to_list:
            if isinstance(v, type) and issubclass(v, ArgsInterface):
                args = v.meta_args_to_add()
                maybe_print_args(v.__name__, [a.argument for a in args])

        print('\n', '-' * 140, '\n', 'all classes that have arguments:', '\n',
              '-' * 140)
        no_args = []
        for v in all_to_list:
            if isinstance(v, type) and issubclass(v, ArgsInterface):
                args = v.args_to_add()
                if not maybe_print_args(v.__name__, args):
                    no_args.append(v)

        print('\n', '-' * 140, '\n', 'classes that do not define arguments:',
              '\n', '-' * 140, '\n')
        for v in no_args:
            print(v.__name__)
Example #10
 def save(net: nn.Module, save_dir: str, name: str, verbose=True) -> dict:
     # saving config now will only use the currently active connections, since we have a search network
     cfg = net.config(finalize=True)
     if (save_dir is not None) and (name is not None):
         path = Builder.save_config(cfg, save_dir, name)
         if verbose:
             print('Saved config: %s' % path)
     return cfg
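
# Usage sketch (not part of the original source): obtaining a network from a task (Main.new_task
# as in Example #6) and saving its finalized config; the run_config path, argument keys, target
# directory, and name are placeholders, and the save helper above is assumed to be callable as shown.
task = Main.new_task('{path_conf_tasks}/d1_dartsv1.run_config',
                     args_changes={'{cls_data}.fake': True, '{cls_task}.is_test_run': True})
cfg = save(task.get_method().get_network(), '{path_tmp}/cfgs/', 'candidate', verbose=True)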
Example #11
 def from_args(cls,
               args: Namespace,
               index=None) -> 'RetrainInsertConfigUninasNetwork':
     """
     :param args: global argparse namespace
     :param index: argument index
     """
     all_parsed = cls._all_parsed_arguments(args, index=index)
     config_path = Builder.find_net_config_path(
         all_parsed.pop('config_path'))
     net = cls._parsed_meta_argument(Register.network_bodies,
                                     'cls_network_body',
                                     args,
                                     index=index)
     net = net.search_network_from_args(args, index=index)
     net.add_cells_from_config(Register.builder.load_config(config_path))
     return cls(model_name=Builder.net_config_name(config_path),
                net=net,
                **all_parsed)
Example #12
    def profile(self, network: SearchUninasNetwork, mover: AbstractDeviceMover,
                batch_size: int):
        """ profile the network """
        assert self.profile_fun is not None, "Can not measure if there is no profile function!"
        sm = StrategyManager()

        # step 1) generate a dataset
        # if other predictors (nearest neighbor, SVM, ...) are attempted at some point,
        # the step-1 code could be moved to a shared parent class

        # number of choices at every position
        max_choices = sm.get_num_choices()
        print("max choices", max_choices)

        # get the search space, we can sample random architectures from it
        space = sm.get_value_space(unique=True)
        for i in range(10):
            print("random arc %d: %s" % (i, space.random_sample()))

        # make sure that a forward pass will not change the network topology
        network.set_forward_strategy(False)

        # find out the size of the network inputs
        shape_in = network.get_shape_in()

        # fix the network architecture, profile it
        sm.forward(fixed_arc=space.random_sample())
        value = self.profile_fun.profile(module=network,
                                         shape_in=shape_in,
                                         mover=mover,
                                         batch_size=batch_size)
        print('value 1', value)

        # alternate way: instead of using one over-complete network that has unused modules,
        # - get the current network architecture (the last set fixed_arc indices will be used now)
        # - build it stand-alone (exactly as the "true" network would be used later), with the same input/output sizes
        # - place it on the profiled device
        # - profile that instead
        # this takes longer, but the mismatch between over-complete and stand-alone networks is interesting to explore
        # this could be made an option via an Argument
        network_config = network.config(finalize=True)
        network_body = Builder().from_config(network_config)
        standalone = RetrainUninasNetwork(model_name='__tmp__',
                                          net=network_body,
                                          checkpoint_path='',
                                          assert_output_match=True)
        standalone.build(network.get_shape_in(), network.get_shape_out()[0])
        standalone = mover.move_module(standalone)
        value = self.profile_fun.profile(module=standalone,
                                         shape_in=shape_in,
                                         mover=mover,
                                         batch_size=batch_size)
        print('value 2', value)
Example #13
def get_dataset(data_kwargs: dict) -> AbstractDataSet:
    Builder()
    # get the data set
    parser = argparse.ArgumentParser()
    cls_data = Register.data_sets.get(data_kwargs.get('cls_data'))
    cls_data.add_arguments(parser, index=None)
    for i, cls_aug in enumerate([Register.augmentation_sets.get(cls) for cls in split(data_kwargs.get('cls_augmentations'))]):
        cls_aug.add_arguments(parser, index=i)
    data_args = parser.parse_args(args=[])
    for k, v in data_kwargs.items():
        data_args.__setattr__(k, v)
    return cls_data.from_args(data_args, index=None)
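
# Usage sketch (not part of the original source): a possible data_kwargs dict; the class names
# are taken from other examples here, and every key/value pair is also set as an attribute on
# the parsed argparse namespace.
data_set = get_dataset({
    'cls_data': 'Cifar10Data',         # data set class name (see Example #14)
    'cls_augmentations': 'CutoutAug',  # comma-separated augmentation class names
})
print(data_set.get_data_shape(), data_set.get_label_shape())  # accessors seen in Example #26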
Example #14
def generate_from_name(name: str, save=True, verbose=True):
    genotype, compact = compact_from_name(name, verbose=verbose)
    run_configs = '{path_conf_tasks}/d1_dartsv1.run_config, {path_conf_net_search}darts.run_config'
    # create weight sharing cell model
    changes = {
        'cls_data': 'Cifar10Data',
        '{cls_data}.fake': True,
        '{cls_task}.save_del_old': False,
        '{cls_network_body}.cell_order': 'n, r',
        '{cls_network_body}.features_first_cell': 36 * 4,
        '{cls_network_stem}.features': 36 * 3,
        'cls_network_cells_primitives': "%s, %s" % (compact.get('primitives'),
                                                    compact.get('primitives')),
    }
    task = Main.new_task(run_configs, args_changes=changes)
    net = task.get_method().get_network()
    args = task.args

    wss = StrategyManager().get_strategies()
    assert len(wss) == 1
    ws = wss[list(wss.keys())[0]]

    # fix arc, all block inputs use different weights
    # go through all weights in the search cell
    for n, w in ws.named_parameters_single():
        # figure out cell type ("normal", "reduce"), block index, and if it's the first, second, ... op of that block
        c_type, block_idx, num_inputs, num_idx = n.split('/')[-4:]
        block_idx = int(block_idx.split('-')[-1])
        num_idx = int(num_idx.split('-')[-1])
        # set all path weights to zero
        w.data.zero_()
        # go through the cell description of the genotype, if input and op number match, set the weight to be higher
        for op_idx, from_idx in compact.get(c_type)[block_idx]:
            if num_idx == from_idx:
                w[op_idx] = 1
    ws.forward()

    # saving config now will only use the highest weighted connections, since we have a search network
    cfg = net.config(finalize=True, num_block_ops=2)
    if save:
        path = Builder.save_config(cfg, get_net_config_dir(genotype.source),
                                   name)
        print('Saved config: %s' % path)
    return net, cfg, args
Example #15
def get_dataset_from_json(path: str, fake=True) -> AbstractDataSet:
    """ parse a task config to re-create the used data set and augmentations """
    Builder()
    args_list = arg_list_from_json(path)
    args_list.append('--{cls_task}.save_dir=""')
    if fake:
        args_list.append('--{cls_data}.fake=True')
    parser = ArgumentParser("tmp")

    node = ArgsTreeNode(Main)
    node.build_from_args(args_list, parser)
    args, wildcards, failed_args, descriptions = node.parse(args_list, parser, raise_unparsed=True)

    return Register.data_sets.get(args.cls_data).from_args(args, index=None)
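
# Usage sketch (not part of the original source): the path is a placeholder for a previously
# saved task config of the kind loaded in Example #22.
data_set = get_dataset_from_json('/tmp/uninas/s1/task.run_config', fake=True)
print(data_set.__class__.__name__)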
Example #16
    def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """ build the network """

        # find the search config
        if not os.path.isfile(self.search_config_path):
            self.search_config_path = Builder.find_net_config_path(
                self.search_config_path, pattern='search')

        # create a temporary search strategy
        tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
        sm = StrategyManager()
        assert len(sm.get_strategies_list()) == 0,\
            "cannot load when there already is a search network"
        sm.add_strategy(tmp_s)
        sm.set_fixed_strategy_name('__tmp__')

        # create a search network
        search_net = Register.builder.load_from_config(self.search_config_path)
        assert isinstance(search_net, SearchUninasNetwork)
        search_net.build(s_in, s_out)
        search_net.set_forward_strategy(False)

        # set the architecture, get the config
        req_gene = ""
        if self.gene == 'random':
            search_net.forward_strategy()
            gene = sm.get_all_finalized_indices(unique=True, flat=True)
            self.model_name = "random(%s)" % str(gene)
            req_gene = " (%s)" % self.gene
        else:
            gene = split(self.gene, int)
        l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene)
        assert l0 == l1, "number of unique choices in the network (%d) must match length of the gene (%d)" % (
            l0, l1)
        search_net.forward_strategy(fixed_arc=gene)
        config = search_net.config(finalize=True)

        # clean up
        sm.delete_strategy('__tmp__')
        del sm
        del search_net

        # build the actually used finalized network
        LoggerManager().get_logger().info(
            "Extracting architecture %s%s from the super-network" %
            (gene, req_gene))
        self.net = Register.builder.from_config(config)
        return self.net.build(s_in, s_out)
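
# Sketch (not part of the original source): how the gene typically reaches this network, using
# the argument key seen in Example #26; the indices are placeholders, one per unique architecture choice.
args_changes = {
    '{cls_network}.gene': '0, 2, 1, 3',  # fixed architecture, parsed above via split(self.gene, int)
    # '{cls_network}.gene': 'random',    # alternatively sample a random architecture
}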
Example #17
def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 4,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_network}.config_path": cfg_path,
    })
    net = exp.get_method().get_network()
    vt = VizTree(net)
    vt.print()
    vt.plot(save_path + 'net', add_subgraphs=True)
    print('Saved cell viz to %s' % save_path)
Example #18
def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 2,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_task}.note": "viz",
        "{cls_network}.config_path": cfg_path,
    })
    net = exp.get_method().get_network()
    for s in ['n', 'r']:
        for cell in net.get_cells():
            if cell.name.startswith(s):
                visualize_cell(cell, save_path, s)
                break
    print('Saved cell viz to %s' % save_path)
Example #19
def main():
    Builder()
    parser = argparse.ArgumentParser(
        description='generate a network config from simple genotype description'
    )
    parser.add_argument('--genotypes',
                        type=str,
                        default=None,
                        help='which config to generate, all available if None')
    args = parser.parse_args()

    if args.genotypes is not None:
        all_genotype_names = [args.genotypes]
    else:
        all_genotype_names = []
        for key, value in list(globals().items()):
            if isinstance(value, NetWrapper):
                all_genotype_names.append(key)

    for genotype_name in all_genotype_names:
        print('Name:\t\t%s' % genotype_name)
        generate_from_name(genotype_name)
Example #20
        self.node.print(indent=0)


def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 4,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_network}.config_path": cfg_path,
    })
    net = exp.get_method().get_network()
    vt = VizTree(net)
    vt.print()
    vt.plot(save_path + 'net', add_subgraphs=True)
    print('Saved cell viz to %s' % save_path)


def visualize_file(config_path: str, save_dir: str):
    config_name = config_path.split('/')[-1].split('.')[0]
    save_path = '%s%s/' % (save_dir, config_name)
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)


if __name__ == '__main__':
    visualize_file(Builder.find_net_config_path('MobileNetV2'), '{path_tmp}/viz/')
Example #21
"""
visualize the augmented data (not normalized, no batch-level augmentations (e.g. MixUp))
"""

import matplotlib.pyplot as plt
import numpy as np
from uninas.utils.torch.standalone import get_imagenet, get_imagenet16
from uninas.builder import Builder


if __name__ == '__main__':
    builder = Builder()

    num_img = 4
    num_transforms = 8
    train_data = True  # [True, False], train or test data (affects data augmentation)

    # get the data set

    data_set = get_imagenet(
        data_dir="{path_data}/ImageNet_ILSVRC2012/",
        batch_size=num_img,
        aug_dict={
            "cls_augmentations": "AAImagenetAug, CutoutAug",
            "DartsImagenetAug#0.crop_size": 224,
            "CutoutAug#1.size": 112,
        },
    )

    """
    data_set = get_imagenet16(
Example #22
def visualize_args_tree(node: ArgsTreeNode):
    g = Digraph(format='pdf',
                engine='dot',
                edge_attr=dict(fontsize='20', fontname="times"),
                node_attr=dict(style='filled',
                               shape='rect',
                               align='center',
                               fontsize='20',
                               height='0.5',
                               penwidth='2',
                               fontname="times"))
    _visualize_args_tree(node, g)
    return g


if __name__ == '__main__':
    from uninas.builder import Builder
    Builder()

    args_list = arg_list_from_json("/tmp/uninas/s1/task.run_config")

    root = ArgsTreeNode(Main)
    root.build_from_args(args_list)
    print("-" * 200)
    visualize_args_tree(root).view(
        filename="args_tree",
        directory=replace_standard_paths("{path_tmp}"),
        cleanup=True,
        quiet_view=True)
Example #23
    def _initialize_weights(self, net: AbstractModule, logger: logging.Logger):
        assert isinstance(net, AbstractUninasNetwork),\
            "This initializer will not work with external networks!"
        search_config = Builder.find_net_config_path(self.path,
                                                     pattern='search')

        checkpoint = CheckpointCallback.load_last_checkpoint(self.path)
        state_dict = checkpoint.get('state_dict')

        # figure out correct weights in super-network checkpoint
        if len(self.gene) > 0:
            log_headline(logger,
                         "tmp network to track used params",
                         target_len=80)
            sm = StrategyManager()
            tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
            assert len(sm.get_strategies_list()) == 0,\
                "cannot load when there already is a search network"
            sm.add_strategy(tmp_s)
            sm.set_fixed_strategy_name('__tmp__')

            search_net = Builder().load_from_config(search_config)
            assert isinstance(search_net, SearchUninasNetwork)
            s_in, s_out = net.get_shape_in(), net.get_shape_out()
            search_net.build(s_in, s_out[0])
            search_net.set_forward_strategy(False)
            search_net.forward_strategy(fixed_arc=self.gene)
            tracker = search_net.track_used_params(
                s_in.random_tensor(batch_size=2))
            # tracker.print()

            logger.info(' > loading weights of gene %s from checkpoint "%s"' %
                        (str(self.gene), self.path))
            target_dict = net.state_dict()
            target_names = list(target_dict.keys())
            new_dict = {}

            # add all stem and head weights, they are at the front of the dict and have pretty much the same name
            log_columns = [('shape in checkpoint', 'name in checkpoint',
                            'name in network', 'shape in network')]
            for k, v in state_dict.items():
                if '.stem.' in k or '.heads.' in k:
                    tn = target_names.pop(0)
                    ts = target_dict[tn].shape
                    log_columns.append(
                        (str(list(v.shape)), k, tn, str(list(ts))))
                    n = k.replace('net.', '', 1)
                    assert n == tn
                    new_dict[n] = v

            # add all cell weights, can generally not compare names, only shapes
            for i, tracker_cell_entry in enumerate(tracker.get_cells()):
                for entry in tracker_cell_entry.get_pareto_best():
                    tn = target_names.pop(0)
                    ts = target_dict[tn].shape
                    log_columns.append((str(list(entry.shape)), entry.name, tn,
                                        str(list(ts))))
                    assert entry.shape == ts,\
                        'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
                    new_dict[tn] = state_dict[entry.name]

            # log matches, load
            log_in_columns(logger, log_columns, add_bullets=True)
            net.load_state_dict(new_dict, strict=self.strict)

            # clean up
            del search_net
            sm.delete_strategy('__tmp__')
            del sm

        # simply load
        else:
            logger.info(' > simply loading state_dict')
            net.load_state_dict(state_dict, strict=self.strict)
Example #24
    def test_output_shapes(self):
        """
        expected output shapes of standard layers
        """
        Builder()
        StrategyManager().delete_strategy('default')
        StrategyManager().add_strategy(RandomChoiceStrategy(max_epochs=1))

        bs, c1, c2, hw1, hw2 = 4, 4, 8, 32, 16
        s_in = Shape([c1, hw1, hw1])
        x = torch.empty(size=[bs] + s_in.shape)

        case_s1_c1 = (c1, 1, Shape([c1, hw1, hw1]))
        case_s1_c2 = (c2, 1, Shape([c2, hw1, hw1]))
        case_s2_c1 = (c1, 2, Shape([c1, hw2, hw2]))
        case_s2_c2 = (c2, 2, Shape([c2, hw2, hw2]))

        for cls, cases, kwargs in [
            (SkipLayer, [case_s1_c1, case_s1_c2], dict()),
            (ZeroLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                         case_s2_c2], dict()),
            (FactorizedReductionLayer, [case_s2_c1, case_s2_c2], dict()),
            (PoolingLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                            case_s2_c2], dict(k_size=3)),
            (ConvLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                         case_s2_c2], dict(k_size=3)),
            (SepConvLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                            case_s2_c2], dict(k_size=3)),
            (MobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2], dict(k_size=3)),
            (MobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_size=(3, ))),
            (MobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(k_size=(3, 5, 7), k_size_in=(1, 1), k_size_out=(1, 1))),
            (FusedMobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(name='mmicl1',
                  k_sizes=(3, 5, 7),
                  k_size_in=(1, 1),
                  k_size_out=(1, 1))),
            (FusedMobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(name='mmicl2',
                  k_sizes=((3, 5), (3, 5, 7)),
                  k_size_in=(1, 1),
                  k_size_out=(1, 1))),
            (ShuffleNetV2Layer, [case_s1_c1, case_s1_c2,
                                 case_s2_c2], dict(k_size=3)),
            (ShuffleNetV2XceptionLayer, [case_s1_c1, case_s1_c2,
                                         case_s2_c2], dict(k_size=3)),
            (LinearTransformerLayer, [case_s1_c1, case_s1_c2], dict()),
            (SuperConvThresholdLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_sizes=(3, 5, 7))),
            (SuperSepConvThresholdLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_sizes=(3, 5, 7))),
            (SuperMobileInvertedConvThresholdLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(k_sizes=(3, 5, 7),
                  expansions=(3, 6),
                  sse_dict=dict(c_muls=(0.0, 0.25, 0.5)))),
            (SuperConvLayer, [case_s1_c1, case_s1_c2, case_s2_c1,
                              case_s2_c2], dict(k_sizes=(3, 5, 7),
                                                name='scl')),
            (SuperSepConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1,
              case_s2_c2], dict(k_sizes=(3, 5, 7), name='sscl')),
            (SuperMobileInvertedConvLayer,
             [case_s1_c1, case_s1_c2, case_s2_c1, case_s2_c2],
             dict(k_sizes=(3, 5, 7), name='smicl', expansions=(3, 6))),
            (AttentionLayer, [case_s1_c1],
             dict(att_dict=dict(att_cls='EfficientChannelAttentionModule'))),
            (AttentionLayer, [case_s1_c1],
             dict(att_dict=dict(att_cls='SqueezeExcitationChannelModule'))),
        ]:
            for c, stride, shape_out in cases:
                m1 = cls(stride=stride, **kwargs)
                s_out = m1.build(s_in, c)
                assert s_out == shape_out, 'Expected output shape does not match, %s, build=%s / expected=%s' %\
                                           (cls.__name__, s_out, shape_out)
                assert_output_shape(m1, x, [bs] + shape_out.shape)
                print('%s(stride=%d, c_in=%d, c_out=%d)' %
                      (cls.__name__, stride, c1, c))
Example #25
def visualize_config(config: dict, save_path: str):
    save_path = replace_standard_paths(save_path)
    cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    exp = Main.new_task(run_config, args_changes={
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 2,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_task}.note": "viz",
        "{cls_network}.config_path": cfg_path,
    })
    net = exp.get_method().get_network()
    for s in ['n', 'r']:
        for cell in net.get_cells():
            if cell.name.startswith(s):
                visualize_cell(cell, save_path, s)
                break
    print('Saved cell viz to %s' % save_path)


def visualize_file(config_path: str, save_dir: str):
    config_name_ = Builder.net_config_name(config_path)
    save_path = save_dir+config_name_+'/'
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)


if __name__ == '__main__':
    visualize_file(Builder.find_net_config_path('DARTS_V1'), '{path_tmp}/viz/')
Example #26
    def make_from_single_dir(cls, path: str, space_name: str,
                             arch_index: int) -> MiniResult:
        """
        creating a mini result by parsing a training process
        """

        # find gene and dataset in the task config
        task_configs = find_all_files(path, extension=name_task_config)
        assert len(task_configs) == 1

        with open(task_configs[0]) as config_file:
            config = json.load(config_file)
            gene = config.get('{cls_network}.gene')
            gene = split(gene, int)

            data_set = get_dataset_from_json(task_configs[0])
            data_set_name = data_set.__class__.__name__

        # find loss and acc in the tensorboard files
        average_last = 5
        metric_accuracy_train, metric_loss_train = "train/accuracy/1", "train/loss"
        metric_accuracy_test, metric_loss_test = "test/accuracy/1", "test/loss"

        tb_files = find_tb_files(path)
        assert len(tb_files) > 0
        events = read_event_files(tb_files)

        loss_train = events.get(metric_loss_train, None)
        loss_test = events.get(metric_loss_test, None)
        assert (loss_train is not None) and (loss_test is not None)
        accuracy_train = events.get(metric_accuracy_train, None)
        accuracy_test = events.get(metric_accuracy_test, None)
        assert (accuracy_train is not None) and (accuracy_test is not None)

        # figure out params and flops by building the network
        net_config_path = Builder.find_net_config_path(path)
        network = get_network(net_config_path, data_set.get_data_shape(),
                              data_set.get_label_shape())

        # figure out latency at some point
        pass

        # return result
        return MiniResult(
            arch_index=arch_index,
            arch_str="%s(%s)" % (space_name, ", ".join([str(g)
                                                        for g in gene])),
            arch_tuple=tuple(gene),
            params={data_set_name: network.get_num_parameters()},
            flops={data_set_name: network.profile_macs()},
            latency={data_set_name: -1},
            loss={
                data_set_name: {
                    'train':
                    np.mean([v.value for v in loss_train[-average_last:]]),
                    'test':
                    np.mean([v.value for v in loss_test[-average_last:]]),
                }
            },
            acc1={
                data_set_name: {
                    'train':
                    np.mean([v.value for v in accuracy_train[-average_last:]]),
                    'test':
                    np.mean([v.value for v in accuracy_test[-average_last:]]),
                }
            },
        )
Example #27
    def test_model_params(self):
        """
        make sure that the models from known network_configs / genotypes have correct number of params
        Imagenet1k uses input size (3, 224, 224)

        differences to originals:
            (1) they use x.mean(3).mean(2) while we use nn.AdaptiveAvgPool2d(1)
            (2) We use the new torch 1.6 swish/hswish/htanh/hsigmoid activation functions
            (3) TODO marginal macs difference for SPOS after changing search-network/primitives code (maybe act fun?)
        the numbers matched exactly when these were accounted for
        """
        Builder()

        # measured via torch.hub
        assert_super_stats_match('ResNet18',
                                 num_params=11689512,
                                 num_macs=1814073856)
        assert_super_stats_match('ResNet34',
                                 num_params=21797672,
                                 num_macs=3663761920)
        assert_super_stats_match('ResNet50',
                                 num_params=25557032,
                                 num_macs=4089186304)
        assert_super_stats_match('ResNet101',
                                 num_params=44549160,
                                 num_macs=7801407488)
        assert_super_stats_match('MobileNetV2',
                                 num_params=3504872,
                                 num_macs=300775552)

        # measured via https://github.com/megvii-model/SinglePathOneShot
        assert_super_stats_match('SPOSNet',
                                 num_params=3558464,
                                 num_macs=322919776 - 16)  # (3)

        # measured via https://github.com/megvii-model/ShuffleNet-Series
        assert_super_stats_match('ShuffleNetV2PlusMedium',
                                 num_params=5679840,
                                 num_macs=224038432 - 1531648 - 16)  # (2), (3)

        # measured via https://github.com/rwightman/pytorch-image-models
        # requires replacing the swish function, otherwise torchprofile tracing fails
        assert_super_stats_match('EfficientNetB0',
                                 num_params=5288548,
                                 num_macs=394289436)
        assert_super_stats_match('MobileNetV3Large100',
                                 num_params=5483032,
                                 num_macs=218703448 - 1511264)  # (2)
        assert_super_stats_match('MobileNetV3Small100',
                                 num_params=2542856,
                                 num_macs=57597784 - 799136)  # (2)
        assert_super_stats_match('MixNetS', num_params=4134606, num_macs=None)
        assert_super_stats_match('MixNetM', num_params=5014382, num_macs=None)

        # measured via https://github.com/mit-han-lab/proxylessnas
        assert_super_stats_match('ProxylessRMobile',
                                 num_params=4080512,
                                 num_macs=320428864)

        # measured via https://github.com/xiaomi-automl/FairNAS
        assert_super_stats_match('FairNasA',
                                 num_params=4651352,
                                 num_macs=388133088 - 8960)  # (1)
        assert_super_stats_match('FairNasB',
                                 num_params=4506272,
                                 num_macs=345307872 - 8960)  # (1)
        assert_super_stats_match('FairNasC',
                                 num_params=4397864,
                                 num_macs=321035232 - 8960)  # (1)

        # measured via https://github.com/xiaomi-automl/SCARLET-NAS
        assert_super_stats_match('ScarletNasA',
                                 num_params=6707720,
                                 num_macs=370705152 - 8960 - 4572288)  # (1, 2)
        assert_super_stats_match('ScarletNasB',
                                 num_params=6531556,
                                 num_macs=332082096 - 8960 - 3932544)  # (1, 2)
        assert_super_stats_match('ScarletNasC',
                                 num_params=6073684,
                                 num_macs=284388720 - 8960 - 3278688)  # (1, 2)

        # measured via https://github.com/cogsys-tuebingen/prdarts
        assert_darts_cifar10_stats_match('DARTS_V1',
                                         num_params=3169414,
                                         num_macs=501015744)
        assert_darts_cifar10_stats_match('PDARTS',
                                         num_params=3433798,
                                         num_macs=532202688)
        assert_darts_cifar10_stats_match('PR_DARTS_DL1',
                                         num_params=3174166,
                                         num_macs=491200704)
        assert_darts_cifar10_stats_match('PR_DARTS_DL2',
                                         num_params=4017646,
                                         num_macs=650370240)
        assert_darts_imagenet_stats_match('DARTS_V1',
                                          num_params=4510432,
                                          num_macs=505893696)
        assert_darts_imagenet_stats_match('PDARTS',
                                          num_params=4944352,
                                          num_macs=542848320)
        assert_darts_imagenet_stats_match('PR_DARTS_DL1',
                                          num_params=4685152,
                                          num_macs=509384064)
        assert_darts_imagenet_stats_match('PR_DARTS_DL2',
                                          num_params=5529856,
                                          num_macs=631603392)

        # measured via https://github.com/tanglang96/MDENAS
        assert_darts_cifar10_stats_match('MdeNAS',
                                         num_params=3786742,
                                         num_macs=599110848)
        assert_darts_imagenet_stats_match('MdeNAS',
                                          num_params=5329024,
                                          num_macs=595514304)
Example #28
def visualize_file(config_path: str, save_dir: str):
    config_name_ = Builder.net_config_name(config_path)
    save_path = save_dir+config_name_+'/'
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)
Example #29
 def __init__(self, hooks=()):
     Builder()
     self.root = GuiArgsTreeNode(Main)
     for hook in hooks:
         self.root.add_hook(hook)
Example #30
def visualize_file(config_path: str, save_dir: str):
    config_name = config_path.split('/')[-1].split('.')[0]
    save_path = '%s%s/' % (save_dir, config_name)
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)