Esempio n. 1
0
    def profile(self, network: SearchUninasNetwork, mover: AbstractDeviceMover,
                batch_size: int):
        """
        Profile one randomly sampled architecture in two different ways and print both results.

        First the over-complete search network is profiled with a fixed architecture,
        then the same architecture is rebuilt as a stand-alone network and profiled again.

        :param network: search network that is profiled
        :param mover: device mover, handed to the profile function and used to place the stand-alone network
        :param batch_size: batch size handed to the profile function
        """
        assert self.profile_fun is not None, "Can not measure if there is no profile function!"
        manager = StrategyManager()

        # Step 1) dataset generation.
        # Should further predictors be tried later on (nearest neighbor, SVM, ...),
        # this part is a candidate to be moved into a shared parent class.

        # how many choices exist at each position
        num_choices = manager.get_num_choices()
        print("max choices", num_choices)

        # the search space can be used to sample random architectures
        arc_space = manager.get_value_space(unique=True)
        for idx in range(10):
            print("random arc %d: %s" % (idx, arc_space.random_sample()))

        # a forward pass must not alter the network topology
        network.set_forward_strategy(False)

        # size of the network inputs
        shape_in = network.get_shape_in()

        def measure(module):
            """ run the profile function on the given module with the shared settings """
            return self.profile_fun.profile(module=module,
                                            shape_in=shape_in,
                                            mover=mover,
                                            batch_size=batch_size)

        # variant 1: fix a random architecture in the over-complete network and profile that
        manager.forward(fixed_arc=arc_space.random_sample())
        value_overcomplete = measure(network)
        print('value 1', value_overcomplete)

        # variant 2: do not profile the over-complete network (which carries unused modules), instead
        # - take the current architecture (the fixed_arc indices set last are used)
        # - build it stand-alone, exactly like the "true" network is used later, with equal input/output sizes
        # - move it to the profiled device
        # - profile the stand-alone network
        # This is slower, but the mismatch between over-complete and stand-alone is very interesting to explore.
        # It could become an option via Argument.
        finalized_config = network.config(finalize=True)
        body = Builder().from_config(finalized_config)
        standalone = RetrainUninasNetwork(model_name='__tmp__',
                                          net=body,
                                          checkpoint_path='',
                                          assert_output_match=True)
        standalone.build(network.get_shape_in(), network.get_shape_out()[0])
        standalone = mover.move_module(standalone)
        value_standalone = measure(standalone)
        print('value 2', value_standalone)
Esempio n. 2
0
class SearchUninasNetwork(AbstractUninasNetwork):
    """
    Network wrapper for architecture search.

    Holds a network body together with a StrategyManager. On a forward pass the
    weight strategies are forwarded first (unless disabled), then the network.
    """

    def __init__(self, model_name: str, net: AbstractNetworkBody, do_forward_strategy=True, *args, **kwargs):
        """
        :param model_name: name of the model, passed to the parent class
        :param net: network body, passed to the parent class
        :param do_forward_strategy: whether forward() also forwards the weight strategies
        """
        # NOTE(review): *args is unpacked after keyword arguments; extra positional arguments would be
        # bound to the parent's leading parameters. Confirm against the parent signature.
        super().__init__(model_name=model_name, net=net, *args, **kwargs)
        self.do_forward_strategy = do_forward_strategy  # unnecessary line to remove "error" highlighting
        self._add_to_kwargs(do_forward_strategy=self.do_forward_strategy)
        self.strategy_manager = StrategyManager()
        # set in _build2, None until the network has been built
        self.strategies = None

    @classmethod
    def from_args(cls, args: Namespace, index=None, weight_strategies: Union[dict, str] = None)\
            -> 'SearchUninasNetwork':
        """
        :param args: global argparse namespace
        :param index: argument index
        :param weight_strategies: {strategy name: [cell indices]}, or name used for all, or None for defaults
        """
        all_parsed = cls._all_parsed_arguments(args)
        cls_net = cls._parsed_meta_argument(Register.network_bodies, 'cls_network_body', args, index=index)
        net = cls_net.search_network_from_args(args, index=index, weight_strategies=weight_strategies)
        return cls(cls.__name__, net, **all_parsed)

    @classmethod
    def meta_args_to_add(cls) -> [MetaArgument]:
        """
        list meta arguments to add to argparse for when this class is chosen,
        classes specified in meta arguments may have their own respective arguments
        """
        return super().meta_args_to_add() + [
            MetaArgument('cls_network_body', Register.network_bodies, help_name='network', allowed_num=1),
        ]

    def _build2(self, s_in: Shape, s_out: Shape) -> ShapeList:
        """
        build the network body, then fetch the strategies from the strategy manager and build the manager

        :param s_in: input shape
        :param s_out: output shape
        :return: whatever the network body's build returns
        """
        s = self.net.build(s_in, s_out)
        self.strategies = self.strategy_manager.get_strategies_list()
        self.strategy_manager.build()
        return s

    def get_strategy_manager(self) -> StrategyManager:
        """ the strategy manager used by this network """
        return self.strategy_manager

    def set_forward_strategy(self, forward_strategy: bool):
        """ set whether forward() also forwards the weight strategies """
        self.do_forward_strategy = forward_strategy

    def get_forward_strategy(self) -> bool:
        """ whether forward() also forwards the weight strategies """
        return self.do_forward_strategy

    def forward(self, x: torch.Tensor, ws_kwargs: dict = None, **net_kwargs) -> [torch.Tensor]:
        """
        forward first the weight strategy, then the network

        :param x: network input
        :param ws_kwargs: keyword arguments for the weight strategies, None is treated as no arguments
        :param net_kwargs: keyword arguments for the parent's forward
        """
        if self.do_forward_strategy:
            self.forward_strategy(**({} if ws_kwargs is None else ws_kwargs))
        return super().forward(x, **net_kwargs)

    def forward_net(self, x: torch.Tensor, **net_kwargs) -> [torch.Tensor]:
        """
        forward only the network
        """
        return self.net(x, **net_kwargs)

    def forward_strategy(self, **ws_kwargs):
        """
        forward only the weight strategy
        """
        self.strategy_manager.forward(**ws_kwargs)

    def str(self, depth=0, **_) -> str:
        """
        string representation, listing the strategies and the network body

        :param depth: indentation depth, every level is rendered as '. '
        """
        # '{d}' and '{d1}' are kept as literal placeholders by format() and only replaced by the
        # indentation afterwards, so that the placeholder inside the 'ws' value is handled as well.
        # NOTE(review): the template opens with "(" but closes with "])", the "]" looks stray;
        # left unchanged so that the emitted text stays identical.
        r = '{d}{name}(\n{ws},{net}\n{d}])'.format(**{
            'd': '{d}',
            'd1': '{d1}',
            'name': self.__class__.__name__,
            'ws': '{d1}Strategies: [%s]' % ', '.join([ws.str() for ws in self.strategies]),
            'net': self.net.str(depth=depth+1, max_depth=self.log_detail, **_),
        })
        r = r.replace('{d}', '. '*depth).replace('{d1}', '. '*(depth+1))
        return r

    def config(self, finalize=True, **_) -> dict:
        """
        :param finalize: if True, return only the config of the network body,
                         otherwise the config of the parent class
        """
        if finalize:
            return self.net.config(finalize=finalize, **_)
        return super().config(finalize=finalize, **_)

    def named_net_arc_parameters(self) -> (list, list):
        """
        split the named parameters into network parameters and architecture parameters

        :return: (named network parameters without the architecture parameters,
                  named parameters of all strategies)
        """
        # all named parameters
        net_params = list(self.net.named_parameters())
        arc_params = []
        for ws in self.strategies:
            arc_params += list(ws.named_parameters())
        # remove arc parameters from the network, compared by identity;
        # filtering (rather than popping collected indices) stays correct
        # even if the same parameter is listed by more than one strategy
        arc_ids = {id(p) for _, p in arc_params}
        net_params = [(n, p) for n, p in net_params if id(p) not in arc_ids]
        return net_params, arc_params

    def track_used_params(self, x: torch.Tensor) -> Tracker:
        """
        track which weights are used for the current architecture,
        and in which cell

        :param x: network input, is moved to the device of this network
        :return: the finalized tracker
        """
        tracker = Tracker()
        is_train = self.training
        handles = []
        self.eval()
        try:
            x = x.to(self.get_device())

            # find all modules that have a weight strategy, add hooks
            for name, module in self.named_modules():
                if hasattr(module, 'ws') and isinstance(module.ws, (AbstractWeightStrategy, StrategyManager)):
                    for name2, m2 in module.named_modules():
                        if len(get_to_print(m2)) >= 1:
                            handles.append(m2.register_forward_hook(Hook(tracker, 'net.%s.%s' % (name, name2))))

            # forward pass with the current arc, all used weights are tracked
            self.forward_net(x)

            tracker.finalize()
        finally:
            # always remove the hooks and restore the previous mode, also if the forward pass failed
            for h in handles:
                h.remove()
            self.train(is_train)
        return tracker

    @classmethod
    def get_space_tuple(cls, unique=True, flat=False) -> tuple:
        """ tuple of final topology """
        return tuple(StrategyManager().get_all_finalized_indices(unique=unique, flat=flat))