Example #1
class DistillerWrapperEnvironment(gym.Env):
    def __init__(self, model, app_args, amc_cfg, services):
        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
        logdir = logging.getLogger().logdir
        self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
        self.verbose = False
        self.orig_model = copy.deepcopy(model)
        self.app_args = app_args
        self.amc_cfg = amc_cfg
        self.services = services

        try:
            modules_list = amc_cfg.modules_dict[app_args.arch]
        except KeyError:
            raise ValueError("The config file does not specify the modules to compress for %s" % app_args.arch)
        self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern)
        self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()
        self.reset(init_only=True)
        self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
        self.episode = 0
        self.best_reward = float("-inf")
        self.action_low = amc_cfg.action_range[0]
        self.action_high = amc_cfg.action_range[1]
        self._log_model_info()
        log_amc_config(amc_cfg)
        self._configure_action_space()
        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
        self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))

        if self.amc_cfg.pruning_method == "fm-reconstruction":
            self._collect_fm_reconstruction_samples(modules_list)

    def _collect_fm_reconstruction_samples(self, modules_list):
        """Run the forward-pass on the selected dataset and collect feature-map samples.

        These data will be used when we optimize the compressed-net's weights by trying
        to reconstruct these samples.
        """
        from functools import partial
        if self.amc_cfg.pruning_pattern != "channels":
            raise ValueError("Feature-map reconstruction is only supported when pruning weights channels")

        def acceptance_criterion(m, mod_names):
            # Collect feature-maps only for Conv2d layers, if they are in our modules list.
            return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names

        # For feature-map reconstruction we need to collect a representative set
        # of inter-layer feature-maps
        from distiller.pruning import FMReconstructionChannelPruner
        collect_intermediate_featuremap_samples(
            self.net_wrapper.model,
            self.net_wrapper.validate,
            partial(acceptance_criterion, mod_names=modules_list),
            partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
                    n_points_per_fm=self.amc_cfg.n_points_per_fm))

    def _log_model_info(self):
        msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch,
                                                               self.net_wrapper.model_metadata.num_layers(),
                                                               self.net_wrapper.model_metadata.num_pruned_layers())
        msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
        msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))

    def _configure_action_space(self):
        if is_using_continuous_action_space(self.amc_cfg.agent_algo):
            if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
            else:
                self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,))
            self.action_space.default_action = self.action_low
        else:
            self.action_space = spaces.Discrete(10)


    @property
    def steps_per_episode(self):
        return self.net_wrapper.model_metadata.num_pruned_layers()

    def reset(self, init_only=False):
        """Reset the environment.
        This is invoked by the Agent.
        """
        msglogger.info("Resetting the environment (init_only={})".format(init_only))
        self.current_state_id = 0
        self.current_layer_id = self.net_wrapper.model_metadata.pruned_idxs[self.current_state_id]
        self.prev_action = 0
        self.model = copy.deepcopy(self.orig_model)
        if hasattr(self.net_wrapper.model, 'intermediate_fms'):
            self.model.intermediate_fms = self.net_wrapper.model.intermediate_fms
        self.net_wrapper.reset(self.model)
        self._removed_macs = 0
        self.action_history = []
        self.agent_action_history = []
        self.model_representation = self.get_model_representation()
        if init_only:
            return
        initial_observation = self.get_obs()
        return initial_observation

    def current_layer(self):
        return self.net_wrapper.get_pruned_layer(self.current_layer_id)

    def episode_is_done(self):
        return self.current_state_id == self.net_wrapper.model_metadata.num_pruned_layers()

    @property
    def removed_macs_pct(self):
        """Return the amount of MACs removed so far.
        This is normalized to the range 0..1
        """
        return self._removed_macs / self.original_model_macs

    def render(self, mode='human'):
        """Provide some feedback to the user about what's going on.
        This is invoked by the Agent.
        """
        if self.current_state_id == 0:
            msglogger.info("+" + "-" * 50 + "+")
            msglogger.info("Starting a new episode %d", self.episode)
            msglogger.info("+" + "-" * 50 + "+")
        if not self.verbose:
            return
        msglogger.info("Render Environment: current_state_id=%d" % self.current_state_id)
        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])

    def step(self, pruning_action):
        """Take a step, given an action.

        The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove).
        This function is invoked by the Agent.
        """
        pruning_action = float(pruning_action[0])
        msglogger.debug("env.step - current_state_id=%d (%s) episode=%d action=%.2f" %
                        (self.current_state_id, self.current_layer().name, self.episode, pruning_action))
        self.agent_action_history.append(pruning_action)

        if is_using_continuous_action_space(self.amc_cfg.agent_algo):
            if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                # We need to map PPO's infinite action-space (actions sampled from a Gaussian) to our action-space.
                pruning_action = adjust_ppo_output(pruning_action, self.action_high, self.action_low)
            else:
                pruning_action = np.clip(pruning_action, self.action_low, self.action_high)
        else:
            # Divide the action space into 10 discrete levels (0%, 10%, 20%, ..., 90% sparsity)
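            # e.g. a sampled discrete action of 3 maps to a pruning fraction of 0.3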
            pruning_action = pruning_action / 10
        msglogger.debug("\tAgent clipped pruning_action={}".format(pruning_action))

        if self.amc_cfg.action_constrain_fn is not None:
            pruning_action = self.amc_cfg.action_constrain_fn(self, pruning_action=pruning_action)
            msglogger.debug("Constrained pruning_action={}".format(pruning_action))

        # Calculate the final compression rate
        total_macs_before, _ = self.net_wrapper.get_resources_requirements()
        layer_macs = self.net_wrapper.layer_macs(self.current_layer())
        msglogger.debug("\tlayer_macs={:.2f}".format(layer_macs / self.original_model_macs))
        msglogger.debug("\tremoved_macs={:.2f}".format(self.removed_macs_pct))
        msglogger.debug("\trest_macs={:.2f}".format(self.rest_macs()))
        msglogger.debug("\tcurrent_layer_id = %d" % self.current_layer_id)
        self.current_state_id += 1
        if pruning_action > 0:
            pruning_action = self.net_wrapper.remove_structures(self.current_layer_id,
                                                                fraction_to_prune=pruning_action,
                                                                prune_what=self.amc_cfg.pruning_pattern,
                                                                prune_how=self.amc_cfg.pruning_method,
                                                                group_size=self.amc_cfg.group_size,
                                                                apply_thinning=self.episode_is_done(),
                                                                ranking_noise=self.amc_cfg.ranking_noise)
                                                                #random_state=self.random_state)
        else:
            pruning_action = 0

        self.action_history.append(pruning_action)
        total_macs_after_act, total_nnz_after_act = self.net_wrapper.get_resources_requirements()
        layer_macs_after_action = self.net_wrapper.layer_macs(self.current_layer())

        # Update the various counters after taking the step
        self._removed_macs += (total_macs_before - total_macs_after_act)

        msglogger.debug("\tactual_action={}".format(pruning_action))
        msglogger.debug("\tlayer_macs={} layer_macs_after_action={} removed now={}".format(layer_macs,
                                                                                        layer_macs_after_action,
                                                                                        (layer_macs - layer_macs_after_action)))
        msglogger.debug("\tself._removed_macs={}".format(self._removed_macs))
        assert math.isclose(layer_macs_after_action / layer_macs, 1 - pruning_action)

        stats = ('Performance/Validation/',
                 OrderedDict([('requested_action', pruning_action)]))

        distiller.log_training_progress(stats, None,
                                        self.episode, steps_completed=self.current_state_id,
                                        total_steps=self.net_wrapper.num_pruned_layers(), log_freq=1, loggers=[self.tflogger])

        if self.episode_is_done():
            msglogger.info("Episode is ending")
            observation = self.get_final_obs()
            reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act)
            normalized_macs = total_macs_after_act / self.original_model_macs * 100
            normalized_nnz = total_nnz_after_act / self.original_model_size * 100
            self.finalize_episode(top1, reward, total_macs_after_act, normalized_macs,
                                  normalized_nnz, self.action_history, self.agent_action_history)
            self.episode += 1
        else:
            self.current_layer_id = self.net_wrapper.model_metadata.pruned_idxs[self.current_state_id]

            if self.amc_cfg.ft_frequency is not None and self.current_state_id % self.amc_cfg.ft_frequency == 0:
                self.net_wrapper.train(1, self.episode)
            observation = self.get_obs()
            if self.amc_cfg.reward_frequency is not None and self.current_state_id % self.amc_cfg.reward_frequency == 0:
                reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act, log_stats=False)
            else:
                reward = 0
        self.prev_action = pruning_action
        if self.episode_is_done():
            info = {"accuracy": top1, "compress_ratio": normalized_macs}
            msglogger.info(self.removed_macs_pct)
            if self.amc_cfg.protocol == "mac-constrained":
                # Sanity check (special case only for "mac-constrained")
                #assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.01
                pass
        else:
            info = {}
        return observation, reward, self.episode_is_done(), info

    def get_obs(self):
        """Produce a state embedding (i.e. an observation)"""
        current_layer_macs = self.net_wrapper.layer_net_macs(self.current_layer())
        current_layer_macs_pct = current_layer_macs/self.original_model_macs
        current_layer = self.current_layer()
        conv_module = distiller.model_find_module(self.model, current_layer.name)

        obs = self.model_representation[self.current_state_id, :]
        obs[-1] = self.prev_action
        obs[-2] = self.rest_macs()
        obs[-3] = self.removed_macs_pct
        msglogger.debug("obs={}".format(Observation._make(obs)))
        # Sanity check
        assert (self.removed_macs_pct + current_layer_macs_pct + self.rest_macs()) <= 1
        return obs

    def get_final_obs(self):
        """Return the final state embedding (observation).

        The final state is reached after we traverse all of the Convolution layers.
        """
        obs = self.model_representation[-1, :]
        msglogger.debug("obs={}".format(Observation._make(obs)))
        return obs

    def get_model_representation(self):
        """Initialize an embedding representation of the entire model.

        At runtime, a specific row in the embedding matrix is chosen (depending on
        the current state) and the dynamic fields in the resulting state-embedding
        vector are updated. 
        """
        num_states = self.net_wrapper.num_pruned_layers()
        network_obs = np.empty(shape=(num_states, ObservationLen))
        for state_id, layer_id in enumerate(self.net_wrapper.model_metadata.pruned_idxs):
            layer = self.net_wrapper.get_layer(layer_id)
            layer_macs = self.net_wrapper.layer_macs(layer)
            conv_module = distiller.model_find_module(self.model, layer.name)
            obs = [state_id,
                   conv_module.out_channels,
                   conv_module.in_channels,
                   layer.ifm_h,
                   layer.ifm_w,
                   layer.stride[0],
                   layer.k,
                   distiller.volume(conv_module.weight),
                   layer_macs,
                   0, 0, 0]
            # Write this layer's static features into its row of the embedding matrix.
            network_obs[state_id, :] = np.array(obs)

        # Feature normalization
        for feature in range(ObservationLen):
            feature_vec = network_obs[:, feature]
            fmin = min(feature_vec)
            fmax = max(feature_vec)
            if fmax - fmin > 0:
                network_obs[:, feature] = (feature_vec - fmin) / (fmax - fmin)
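        # e.g. min-max scaling maps a feature column [10, 20, 30] to [0.0, 0.5, 1.0].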
        # msglogger.debug("model representation=\n{}".format(network_obs))
        return network_obs

    def rest_macs_raw(self):
        """Return the number of remaining MACs in the layers following the current layer"""
        rest, prunable_rest = 0, 0
        prunable_layers, rest_layers, layers_to_ignore = list(), list(), list()

        # Create a list of the IDs of the layers that are dependent on the current_layer.
        # We want to ignore these layers when we compute prunable_layers (and prunable_rest).
        for dependent_mod in self.current_layer().dependencies:
            layers_to_ignore.append(self.net_wrapper.name2layer(dependent_mod).id)

        for layer_id in range(self.current_layer_id+1, self.net_wrapper.model_metadata.num_layers()):
            layer_macs = self.net_wrapper.layer_net_macs(self.net_wrapper.get_layer(layer_id))
            if self.net_wrapper.model_metadata.is_reducible(layer_id):
                if layer_id not in layers_to_ignore:
                    prunable_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
                    prunable_rest += layer_macs
            else:
                rest_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
                rest += layer_macs

        msglogger.debug("prunable_layers={} rest_layers={}".format(prunable_layers, rest_layers))
        msglogger.debug("layer_id=%d, prunable_rest=%.3f rest=%.3f" % (self.current_layer_id, prunable_rest, rest))
        return prunable_rest, rest

    def rest_macs(self):
        return sum(self.rest_macs_raw()) / self.original_model_macs

    def is_macs_constraint_achieved(self, compressed_model_total_macs):
        current_density = compressed_model_total_macs / self.original_model_macs
        return self.amc_cfg.target_density >= current_density

    def compute_reward(self, total_macs, total_nnz, log_stats=True):
        """Compute the reward.

        We use the validation dataset (the size of the validation dataset is
        configured when the data-loader is instantiated)"""
        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
        compression = distiller.model_numel(self.model, param_dims=[4]) / self.original_model_size

        # Fine-tune (this is a nop if self.amc_cfg.num_ft_epochs==0)
        accuracies = self.net_wrapper.train(self.amc_cfg.num_ft_epochs, self.episode)
        self.ft_stats_logger.add_record([self.episode, accuracies])

        top1, top5, vloss = self.net_wrapper.validate()
        reward = self.amc_cfg.reward_fn(self, top1, top5, vloss, total_macs)

        if log_stats:
            macs_normalized = total_macs/self.original_model_macs
            msglogger.info("Total parameters left: %.2f%%" % (compression*100))
            msglogger.info("Total compute left: %.2f%%" % (total_macs/self.original_model_macs*100))

            stats = ('Performance/EpisodeEnd/',
                     OrderedDict([('Loss', vloss),
                                  ('Top1', top1),
                                  ('Top5', top5),
                                  ('reward', reward),
                                  ('total_macs', int(total_macs)),
                                  ('macs_normalized', macs_normalized*100),
                                  ('log(total_macs)', math.log(total_macs)),
                                  ('total_nnz', int(total_nnz))]))
            distiller.log_training_progress(stats, None, self.episode, steps_completed=0, total_steps=1,
                                            log_freq=1, loggers=[self.tflogger, self.pylogger])
        return reward, top1

    def finalize_episode(self, top1, reward, total_macs, normalized_macs,
                         normalized_nnz, action_history, agent_action_history):
        """Write the details of one network to the logger and create a checkpoint file"""
        if reward > self.best_reward:
            self.best_reward = reward
            ckpt_name = self.save_checkpoint(is_best=True)
            msglogger.info("Best reward={}  episode={}  top1={}".format(reward, self.episode, top1))
        else:
            ckpt_name = self.save_checkpoint(is_best=False)

        import json
        performance = self.net_wrapper.performance_summary()
        fields = [self.episode, top1, reward, total_macs, normalized_macs, normalized_nnz,
                  ckpt_name, json.dumps(action_history), json.dumps(agent_action_history),
                  json.dumps(performance)]
        self.stats_logger.add_record(fields)

    def save_checkpoint(self, is_best=False):
        """Save the learned-model checkpoint"""
        episode = str(self.episode).zfill(3)
        if is_best:
            fname = "BEST_adc_episode_{}".format(episode)
        else:
            fname = "adc_episode_{}".format(episode)
        if is_best or self.amc_cfg.save_chkpts:
            # Always save the best episodes, and depending on amc_cfg.save_chkpts save all other episodes
            scheduler = self.net_wrapper.create_scheduler()
            self.services.save_checkpoint_fn(epoch=0, model=self.model,
                                             scheduler=scheduler, name=fname)
            del scheduler
        return fname
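
A minimal driver sketch for the environment above. It assumes that model, app_args, amc_cfg, and services have already been constructed (as Distiller's AMC sample application does) and that a continuous-action agent algorithm is configured, so that sampled actions are numpy arrays of shape (1,); the random policy below is illustrative only and not part of the source:

env = DistillerWrapperEnvironment(model, app_args, amc_cfg, services)
observation = env.reset()
done = False
while not done:
    # Stand-in for an RL agent: sample a pruning fraction for the current layer.
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
print(info)  # at episode end: {"accuracy": top1, "compress_ratio": normalized_macs}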
Example #2
    def __init__(self, model, app_args, amc_cfg, services):
        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
        logdir = logging.getLogger().logdir
        self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
        self.verbose = False
        self.orig_model = copy.deepcopy(model)
        self.app_args = app_args
        self.amc_cfg = amc_cfg
        self.services = services

        try:
            modules_list = amc_cfg.modules_dict[app_args.arch]
        except KeyError:
            msglogger.warning(
                "!!! The config file does not specify the modules to compress for %s"
                % app_args.arch)
            # Default to using all convolution layers
            distiller.assign_layer_fq_names(model)
            modules_list = [
                mod.distiller_name for mod in model.modules()
                if type(mod) == torch.nn.Conv2d
            ]
            msglogger.warning("Using the following layers: %s" %
                              ", ".join(modules_list))

        self.net_wrapper = NetworkWrapper(model, app_args, services,
                                          modules_list,
                                          amc_cfg.pruning_pattern)
        self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()
        self.reset(init_only=True)
        msglogger.debug("Model %s has %d modules (%d pruned)",
                        self.app_args.arch,
                        self.net_wrapper.model_metadata.num_layers(),
                        self.net_wrapper.model_metadata.num_pruned_layers())
        msglogger.debug("\tTotal MACs: %s" %
                        distiller.pretty_int(self.original_model_macs))
        msglogger.debug("\tTotal weights: %s" %
                        distiller.pretty_int(self.original_model_size))
        self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
        log_amc_config(amc_cfg)

        self.episode = 0
        self.best_reward = float("-inf")
        self.action_low = amc_cfg.action_range[0]
        self.action_high = amc_cfg.action_range[1]

        if is_using_continuous_action_space(self.amc_cfg.agent_algo):
            if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1, ))
            else:
                self.action_space = spaces.Box(self.action_low,
                                               self.action_high,
                                               shape=(1, ))
            self.action_space.default_action = self.action_low
        else:
            self.action_space = spaces.Discrete(10)
        self.observation_space = spaces.Box(0,
                                            float("inf"),
                                            shape=(len(Observation._fields), ))
        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
        self.ft_stats_logger = FineTuneStatsLogger(
            os.path.join(logdir, 'ft_top1.csv'))

        if self.amc_cfg.pruning_method == "fm-reconstruction":
            if self.amc_cfg.pruning_pattern != "channels":
                raise ValueError(
                    "Feature-map reconstruction is only supported when pruning weights channels"
                )

            from functools import partial

            def acceptance_criterion(m, mod_names):
                # Collect feature-maps only for Conv2d layers, if they are in our modules list.
                return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names

            # For feature-map reconstruction we need to collect a representative set
            # of inter-layer feature-maps
            from distiller.pruning import FMReconstructionChannelPruner
            collect_intermediate_featuremap_samples(
                self.net_wrapper.model, self.net_wrapper.validate,
                partial(acceptance_criterion, mod_names=modules_list),
                partial(
                    FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
                    n_points_per_fm=self.amc_cfg.n_points_per_fm))
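
This variant differs from the first example in its KeyError fallback: instead of raising ValueError, it defaults to compressing every Conv2d layer in the model. A standalone sketch of that selection logic, using torchvision's resnet18 purely for illustration:

import torch
import torchvision
import distiller

model = torchvision.models.resnet18()
distiller.assign_layer_fq_names(model)  # attaches mod.distiller_name to every module
modules_list = [mod.distiller_name for mod in model.modules()
                if type(mod) == torch.nn.Conv2d]
print(modules_list[:3])  # e.g. ['conv1', 'layer1.0.conv1', 'layer1.0.conv2']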