コード例 #1
0
ファイル: retrain.py プロジェクト: Light-Reflection/uninas
    def extend_args(cls, args_list: [str]):
        """
        Adjust the raw argument list before other classes' arguments are added dynamically.

        After the parent implementation ran, the network config file referenced in the arguments
        is located, and the classes it uses (the network body plus every meta argument of that body
        flagged as optional_for_loading) are appended as '--name=value' entries unless the
        argument list already sets them.
        Use sparingly, modifications made here are hard to keep track of.

        :param args_list: list of argument strings, extended in place
        """
        super().extend_args(args_list)

        # the config path may be given with the generic '{cls_network}' prefix or this class's name
        candidate_keys = ['{cls_network}.config_path', '%s.config_path' % cls.__name__]
        config_path = Builder.find_net_config_path(find_in_args_list(args_list, candidate_keys))

        # classes that the network config file uses
        used_classes = Builder().find_classes_in_config(config_path)
        body = Register.network_bodies.get(used_classes['cls_network_body'][0])
        keys_to_check = ['cls_network_body']
        for meta in body.meta_args_to_add():
            if meta.optional_for_loading:
                keys_to_check.append(meta.argument.name)

        print('\tbuilding a new net (config_only=False), added missing args from the network config file')
        for key in keys_to_check:
            present = find_in_args_list(args_list, [key])
            if present is not None and len(present) > 0:
                # explicitly set by the user, keep it
                continue
            value = ', '.join(used_classes[key])
            print('\t  %s -> %s' % (key, value))
            args_list.append('--%s=%s' % (key, value))
コード例 #2
0
ファイル: retrain.py プロジェクト: Light-Reflection/uninas
 def from_args(cls, args: Namespace, index=None) -> 'RetrainUninasNetwork':
     """
     Create the network wrapper from the parsed global arguments.

     The 'config_path' argument is resolved to a network config file, the network is loaded
     from it, and the remaining parsed arguments are passed on to the constructor.

     :param args: global argparse namespace
     :param index: index for the args
     :return: wrapper named after the resolved config file
     """
     parsed = cls._all_parsed_arguments(args, index=index)
     resolved_path = Builder.find_net_config_path(parsed.pop('config_path'))
     loaded_net = Register.builder.load_from_config(resolved_path)
     return cls(model_name=Builder.net_config_name(resolved_path), net=loaded_net, **parsed)
コード例 #3
0
ファイル: retrain.py プロジェクト: Light-Reflection/uninas
    def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        Build the network by extracting one architecture from a search (super-)network.

        A temporary search network is loaded from ``self.search_config_path``. It is fixed to
        the architecture in ``self.gene``, or to a random one if the gene is 'random'. Its
        finalized config is then used to build ``self.net``.

        :param s_in: input shape
        :param s_out: output shape
        :return: result of building the finalized network
        """

        # find the search config, unless an existing file was given directly
        if not os.path.isfile(self.search_config_path):
            self.search_config_path = Builder.find_net_config_path(
                self.search_config_path, pattern='search')

        # create a temporary search strategy
        tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
        sm = StrategyManager()
        assert len(sm.get_strategies_list(
        )) == 0, "can not load when there already is a search network"
        sm.add_strategy(tmp_s)
        # The assertion above implies that StrategyManager state outlives this call.
        # The temporary strategy must therefore be removed even if anything below fails;
        # otherwise every later attempt would trip that assertion.
        try:
            sm.set_fixed_strategy_name('__tmp__')

            # create a search network
            search_net = Register.builder.load_from_config(self.search_config_path)
            assert isinstance(search_net, SearchUninasNetwork)
            search_net.build(s_in, s_out)
            search_net.set_forward_strategy(False)

            # set the architecture, get the config
            req_gene = ""
            if self.gene == 'random':
                search_net.forward_strategy()
                gene = sm.get_all_finalized_indices(unique=True, flat=True)
                self.model_name = "random(%s)" % str(gene)
                req_gene = " (%s)" % self.gene
            else:
                gene = split(self.gene, int)
            l0, l1 = len(sm.get_all_finalized_indices(unique=True)), len(gene)
            assert l0 == l1, "number of unique choices in the network (%d) must match length of the gene (%d)" % (
                l0, l1)
            search_net.forward_strategy(fixed_arc=gene)
            config = search_net.config(finalize=True)
        finally:
            # clean up
            sm.delete_strategy('__tmp__')
        del sm
        # release the search network before building the final one
        del search_net

        # build the actually used finalized network
        LoggerManager().get_logger().info(
            "Extracting architecture %s%s from the super-network" %
            (gene, req_gene))
        self.net = Register.builder.from_config(config)
        return self.net.build(s_in, s_out)
コード例 #4
0
ファイル: standalone.py プロジェクト: Light-Reflection/uninas
def get_network(config_path: str, input_shape: Shape, output_shape: Shape, weights_path: str = None) -> AbstractUninasNetwork:
    """
    Build a standalone network (model) from a network config file, optionally loading weights.

    :param config_path: network config path or name, resolved via Builder.find_net_config_path
    :param input_shape: shape of the network input
    :param output_shape: shape of the network output
    :param weights_path: checkpoint to load weights from; only used when it is a str
    :return: the built network wrapper
    """
    builder = Builder()
    inner = builder.load_from_config(Builder.find_net_config_path(config_path))
    wrapped = AbstractUninasNetwork(model_name="standalone", net=inner, checkpoint_path="", assert_output_match=True)
    wrapped.build(s_in=input_shape, s_out=output_shape)

    if not isinstance(weights_path, str):
        return wrapped
    # the weights were saved from a method, so the key names need mapping (num_replacements=1)
    CheckpointCallback.load_network(weights_path, wrapped, num_replacements=1)
    return wrapped
コード例 #5
0
ファイル: retrain.py プロジェクト: Light-Reflection/uninas
 def from_args(cls, args: Namespace, index=None) -> 'RetrainInsertConfigUninasNetwork':
     """
     Create the network wrapper from the parsed global arguments.

     A search network is created from the 'cls_network_body' meta argument. The cells
     stored in the resolved config file are then added to it.

     :param args: global argparse namespace
     :param index: argument index
     :return: wrapper named after the resolved config file
     """
     parsed = cls._all_parsed_arguments(args, index=index)
     cfg_path = Builder.find_net_config_path(parsed.pop('config_path'))
     body = cls._parsed_meta_argument(Register.network_bodies, 'cls_network_body', args, index=index)
     search_net = body.search_network_from_args(args, index=index)
     cells_config = Register.builder.load_config(cfg_path)
     search_net.add_cells_from_config(cells_config)
     return cls(model_name=Builder.net_config_name(cfg_path), net=search_net, **parsed)
コード例 #6
0
    def _initialize_weights(self, net: AbstractModule, logger: logging.Logger):
        """
        Load weights from the last checkpoint in ``self.path`` into ``net``.

        Gene given (``self.gene`` non-empty):
        - A temporary search network is built from the search config in ``self.path``.
        - It is fixed to ``self.gene``, and its used parameters are tracked.
        - Stem/head weights are copied by (checked) name.
        - Cell weights are copied by order, checked only by shape.

        No gene: the checkpoint's state_dict is loaded as is.

        :param net: network to initialize, must be an AbstractUninasNetwork
        :param logger: logger for progress messages and the checkpoint-to-network mapping table
        """
        assert isinstance(
            net, AbstractUninasNetwork
        ), "This initializer will not work with external networks!"
        search_config = Builder.find_net_config_path(self.path,
                                                     pattern='search')

        checkpoint = CheckpointCallback.load_last_checkpoint(self.path)
        state_dict = checkpoint.get('state_dict')

        # figure out correct weights in super-network checkpoint
        if len(self.gene) > 0:
            log_headline(logger,
                         "tmp network to track used params",
                         target_len=80)
            sm = StrategyManager()
            tmp_s = RandomChoiceStrategy(max_epochs=1, name='__tmp__')
            assert len(sm.get_strategies_list(
            )) == 0, "can not load when there already is a search network"
            sm.add_strategy(tmp_s)
            # The assertion above implies that StrategyManager state outlives this call.
            # The temporary strategy must therefore be removed even if anything below fails;
            # otherwise every later attempt would trip that assertion.
            try:
                sm.set_fixed_strategy_name('__tmp__')

                search_net = Builder().load_from_config(search_config)
                assert isinstance(search_net, SearchUninasNetwork)
                s_in, s_out = net.get_shape_in(), net.get_shape_out()
                search_net.build(s_in, s_out[0])
                search_net.set_forward_strategy(False)
                search_net.forward_strategy(fixed_arc=self.gene)
                # track the params used by the fixed architecture, using a random batch of 2
                tracker = search_net.track_used_params(
                    s_in.random_tensor(batch_size=2))

                logger.info(' > loading weights of gene %s from checkpoint "%s"' %
                            (str(self.gene), self.path))
                target_dict = net.state_dict()
                target_names = list(target_dict.keys())
                new_dict = {}

                # add all stem and head weights, they are at the front of the dict and have pretty much the same name
                log_columns = [('shape in checkpoint', 'name in checkpoint',
                                'name in network', 'shape in network')]
                for k, v in state_dict.items():
                    if '.stem.' in k or '.heads.' in k:
                        tn = target_names.pop(0)
                        ts = target_dict[tn].shape
                        log_columns.append(
                            (str(list(v.shape)), k, tn, str(list(ts))))
                        # drop the first 'net.' of the checkpoint key before comparing names
                        n = k.replace('net.', '', 1)
                        assert n == tn, \
                            'Mismatching stem/head weights "%s" (checkpoint) and "%s" (network)' % (k, tn)
                        new_dict[n] = v

                # add all cell weights, can generally not compare names, only shapes
                for tracker_cell_entry in tracker.get_cells():
                    for entry in tracker_cell_entry.get_pareto_best():
                        tn = target_names.pop(0)
                        ts = target_dict[tn].shape
                        log_columns.append((str(list(entry.shape)), entry.name, tn,
                                            str(list(ts))))
                        assert entry.shape == ts,\
                            'Mismatching shapes for "%s" and "%s", is the gene correct?' % (entry.name, tn)
                        new_dict[tn] = state_dict[entry.name]

                # log matches, load
                log_in_columns(logger, log_columns, add_bullets=True)
                net.load_state_dict(new_dict, strict=self.strict)

                # clean up
                del search_net
            finally:
                sm.delete_strategy('__tmp__')
            del sm

        # simply load
        else:
            logger.info(' > simply loading state_dict')
            net.load_state_dict(state_dict, strict=self.strict)
コード例 #7
0
    def make_from_single_dir(cls, path: str, space_name: str,
                             arch_index: int) -> MiniResult:
        """
        Create a mini result by parsing the outputs of a single training process.

        The following are read from ``path``:
        - the gene, from the (single) task config;
        - train/test loss and top-1 accuracy from the tensorboard event files,
          each averaged over the last 5 logged values;
        - parameter count and MACs, by rebuilding the network from its config.

        :param path: directory of the training process
        :param space_name: search space name, used as prefix of the architecture string
        :param arch_index: index of the architecture
        :return: MiniResult keyed by the dataset class name; latency is -1 (not measured)
        """

        # find gene and dataset in the task config
        task_configs = find_all_files(path, extension=name_task_config)
        assert len(task_configs) == 1, \
            'expected exactly one task config in "%s", found %d' % (path, len(task_configs))
        task_config_path = task_configs[0]

        # only the JSON parsing needs the open file; JSON text is UTF-8 regardless of the locale
        with open(task_config_path, encoding='utf-8') as config_file:
            config = json.load(config_file)
        gene_str = config.get('{cls_network}.gene')
        assert gene_str is not None, 'no "{cls_network}.gene" entry in "%s"' % task_config_path
        gene = split(gene_str, int)

        data_set = get_dataset_from_json(task_config_path)
        data_set_name = data_set.__class__.__name__

        # find loss and acc in the tensorboard files
        average_last = 5
        metric_accuracy_train, metric_loss_train = "train/accuracy/1", "train/loss"
        metric_accuracy_test, metric_loss_test = "test/accuracy/1", "test/loss"

        tb_files = find_tb_files(path)
        assert len(tb_files) > 0, 'no tensorboard files found in "%s"' % path
        events = read_event_files(tb_files)

        loss_train = events.get(metric_loss_train, None)
        loss_test = events.get(metric_loss_test, None)
        assert (loss_train is not None) and (loss_test is not None), \
            'missing "%s" or "%s" in the tensorboard files of "%s"' % (metric_loss_train, metric_loss_test, path)
        accuracy_train = events.get(metric_accuracy_train, None)
        accuracy_test = events.get(metric_accuracy_test, None)
        assert (accuracy_train is not None) and (accuracy_test is not None), \
            'missing "%s" or "%s" in the tensorboard files of "%s"' % (metric_accuracy_train, metric_accuracy_test, path)

        def mean_of_last(logged) -> float:
            """ average the .value of the last `average_last` logged events """
            return np.mean([v.value for v in logged[-average_last:]])

        # figure out params and flops by building the network
        net_config_path = Builder.find_net_config_path(path)
        network = get_network(net_config_path, data_set.get_data_shape(),
                              data_set.get_label_shape())

        # TODO: measure latency; reported as -1 until then

        # return result
        return MiniResult(
            arch_index=arch_index,
            arch_str="%s(%s)" % (space_name, ", ".join([str(g) for g in gene])),
            arch_tuple=tuple(gene),
            params={data_set_name: network.get_num_parameters()},
            flops={data_set_name: network.profile_macs()},
            latency={data_set_name: -1},
            loss={
                data_set_name: {
                    'train': mean_of_last(loss_train),
                    'test': mean_of_last(loss_test),
                }
            },
            acc1={
                data_set_name: {
                    'train': mean_of_last(accuracy_train),
                    'test': mean_of_last(accuracy_test),
                }
            },
        )
コード例 #8
0
        self.node.print(indent=0)


def visualize_config(config: dict, save_path: str):
    """
    Build the network described by ``config`` in a small fake-data test task, then print and plot its tree.

    :param config: network config, stored to a temporary file before the task is created
    :param save_path: output prefix; standard path placeholders (e.g. '{path_tmp}') are replaced
    """
    target_dir = replace_standard_paths(save_path)
    tmp_cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    changes = {
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 4,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_network}.config_path": tmp_cfg_path,
    }
    task = Main.new_task(run_config, args_changes=changes)
    tree = VizTree(task.get_method().get_network())
    tree.print()
    tree.plot(target_dir + 'net', add_subgraphs=True)
    print('Saved cell viz to %s' % target_dir)


def visualize_file(config_path: str, save_dir: str):
    """
    Load a network config file and visualize it into '<save_dir><config name>/'.

    :param config_path: path to the network config file
    :param save_dir: output directory; it is prefixed directly, so it should end with '/'
    """
    import os

    # os.path.basename understands the platform's native separators (e.g. '\\' on Windows),
    # whereas splitting on '/' would keep the whole path there and produce a broken save path
    config_name = os.path.basename(config_path).split('.')[0]
    save_path = '%s%s/' % (save_dir, config_name)
    config = Builder.load_config(config_path)
    visualize_config(config, save_path)


if __name__ == '__main__':
    # visualize the MobileNetV2 network config into the temporary viz directory
    mobilenet_config = Builder.find_net_config_path('MobileNetV2')
    visualize_file(mobilenet_config, '{path_tmp}/viz/')
コード例 #9
0
def visualize_config(config: dict, save_path: str):
    """
    Build the network described by ``config`` in a small fake-data test task and visualize some cells.

    For each of the name prefixes 'n' and 'r' (presumably normal / reduction cells), the first
    cell whose name starts with that prefix is visualized.

    :param config: network config, stored to a temporary file before the task is created
    :param save_path: output prefix; standard path placeholders (e.g. '{path_tmp}') are replaced
    """
    target_dir = replace_standard_paths(save_path)
    tmp_cfg_path = Builder.save_config(config, replace_standard_paths('{path_tmp}/viz/'), 'viz')
    changes = {
        '{cls_data}.fake': True,
        '{cls_data}.batch_size_train': 2,
        '{cls_task}.is_test_run': True,
        '{cls_task}.save_dir': '{path_tmp}/viz/task/',
        '{cls_task}.save_del_old': True,
        "{cls_task}.note": "viz",
        "{cls_network}.config_path": tmp_cfg_path,
    }
    task = Main.new_task(run_config, args_changes=changes)
    network = task.get_method().get_network()
    for prefix in ('n', 'r'):
        first_cell = next((c for c in network.get_cells() if c.name.startswith(prefix)), None)
        if first_cell is not None:
            visualize_cell(first_cell, target_dir, prefix)
    print('Saved cell viz to %s' % target_dir)


def visualize_file(config_path: str, save_dir: str):
    """
    Load a network config file and visualize it into '<save_dir><config name>/'.

    :param config_path: path to the network config file
    :param save_dir: output directory; it is prefixed directly, so it should end with '/'
    """
    target = ''.join((save_dir, Builder.net_config_name(config_path), '/'))
    visualize_config(Builder.load_config(config_path), target)


if __name__ == '__main__':
    # visualize the DARTS_V1 network config into the temporary viz directory
    darts_config = Builder.find_net_config_path('DARTS_V1')
    visualize_file(darts_config, '{path_tmp}/viz/')