class DistillerWrapperEnvironment(gym.Env):
    def __init__(self, model, app_args, amc_cfg, services):
        self.pylogger = distiller.data_loggers.PythonLogger(msglogger)
        logdir = logging.getLogger().logdir
        self.tflogger = distiller.data_loggers.TensorBoardLogger(logdir)
        self.verbose = False
        self.orig_model = copy.deepcopy(model)
        self.app_args = app_args
        self.amc_cfg = amc_cfg
        self.services = services

        try:
            modules_list = amc_cfg.modules_dict[app_args.arch]
        except KeyError:
            raise ValueError("The config file does not specify the modules to compress for %s" % app_args.arch)

        self.net_wrapper = NetworkWrapper(model, app_args, services, modules_list, amc_cfg.pruning_pattern)
        self.original_model_macs, self.original_model_size = self.net_wrapper.get_resources_requirements()
        self.reset(init_only=True)

        self._max_episode_steps = self.net_wrapper.model_metadata.num_pruned_layers()  # Hack for Coach-TD3
        self.episode = 0
        self.best_reward = float("-inf")
        self.action_low = amc_cfg.action_range[0]
        self.action_high = amc_cfg.action_range[1]

        self._log_model_info()
        log_amc_config(amc_cfg)
        self._configure_action_space()
        self.observation_space = spaces.Box(0, float("inf"), shape=(len(Observation._fields),))
        self.stats_logger = AMCStatsLogger(os.path.join(logdir, 'amc.csv'))
        self.ft_stats_logger = FineTuneStatsLogger(os.path.join(logdir, 'ft_top1.csv'))

        if self.amc_cfg.pruning_method == "fm-reconstruction":
            self._collect_fm_reconstruction_samples(modules_list)

    def _collect_fm_reconstruction_samples(self, modules_list):
        """Run the forward-pass on the selected dataset and collect feature-map samples.

        These data will be used when we optimize the compressed-net's weights by
        trying to reconstruct these samples.
        """
        from functools import partial
        if self.amc_cfg.pruning_pattern != "channels":
            raise ValueError("Feature-map reconstruction is only supported when pruning weights channels")

        def acceptance_criterion(m, mod_names):
            # Collect feature-maps only for Conv2d layers, if they are in our modules list.
            return isinstance(m, torch.nn.Conv2d) and m.distiller_name in mod_names

        # For feature-map reconstruction we need to collect a representative set
        # of inter-layer feature-maps
        from distiller.pruning import FMReconstructionChannelPruner
        collect_intermediate_featuremap_samples(
            self.net_wrapper.model,
            self.net_wrapper.validate,
            partial(acceptance_criterion, mod_names=modules_list),
            partial(FMReconstructionChannelPruner.cache_featuremaps_fwd_hook,
                    n_points_per_fm=self.amc_cfg.n_points_per_fm))

    def _log_model_info(self):
        msglogger.debug("Model %s has %d modules (%d pruned)", self.app_args.arch,
                        self.net_wrapper.model_metadata.num_layers(),
                        self.net_wrapper.model_metadata.num_pruned_layers())
        msglogger.debug("\tTotal MACs: %s" % distiller.pretty_int(self.original_model_macs))
        msglogger.debug("\tTotal weights: %s" % distiller.pretty_int(self.original_model_size))

    def _configure_action_space(self):
        if is_using_continuous_action_space(self.amc_cfg.agent_algo):
            if self.amc_cfg.agent_algo == "ClippedPPO-continuous":
                self.action_space = spaces.Box(PPO_MIN, PPO_MAX, shape=(1,))
            else:
                self.action_space = spaces.Box(self.action_low, self.action_high, shape=(1,))
            self.action_space.default_action = self.action_low
        else:
            self.action_space = spaces.Discrete(10)

    @property
    def steps_per_episode(self):
        return self.net_wrapper.model_metadata.num_pruned_layers()

    def reset(self, init_only=False):
        """Reset the environment.

        This is invoked by the Agent.
        """
""" msglogger.info("Resetting the environment (init_only={})".format(init_only)) self.current_state_id = 0 self.current_layer_id = self.net_wrapper.model_metadata.pruned_idxs[self.current_state_id] self.prev_action = 0 self.model = copy.deepcopy(self.orig_model) if hasattr(self.net_wrapper.model, 'intermediate_fms'): self.model.intermediate_fms = self.net_wrapper.model.intermediate_fms self.net_wrapper.reset(self.model) self._removed_macs = 0 self.action_history = [] self.agent_action_history = [] self.model_representation = self.get_model_representation() if init_only: return initial_observation = self.get_obs() return initial_observation def current_layer(self): return self.net_wrapper.get_pruned_layer(self.current_layer_id) def episode_is_done(self): return self.current_state_id == self.net_wrapper.model_metadata.num_pruned_layers() @property def removed_macs_pct(self): """Return the amount of MACs removed so far. This is normalized to the range 0..1 """ return self._removed_macs / self.original_model_macs def render(self, mode='human'): """Provide some feedback to the user about what's going on. This is invoked by the Agent. """ if self.current_state_id == 0: msglogger.info("+" + "-" * 50 + "+") msglogger.info("Starting a new episode %d", self.episode) msglogger.info("+" + "-" * 50 + "+") if not self.verbose: return msglogger.info("Render Environment: current_state_id=%d" % self.current_state_id) distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger]) def step(self, pruning_action): """Take a step, given an action. The action represents the desired sparsity for the "current" layer (i.e. the percentage of weights to remove). This function is invoked by the Agent. """ pruning_action = float(pruning_action[0]) msglogger.debug("env.step - current_state_id=%d (%s) episode=%d action=%.2f" % (self.current_state_id, self.current_layer().name, self.episode, pruning_action)) self.agent_action_history.append(pruning_action) if is_using_continuous_action_space(self.amc_cfg.agent_algo): if self.amc_cfg.agent_algo == "ClippedPPO-continuous": # We need to map PPO's infinite action-space (actions sampled from a Gaussian) to our action-space. 
                pruning_action = adjust_ppo_output(pruning_action, self.action_high, self.action_low)
            else:
                pruning_action = np.clip(pruning_action, self.action_low, self.action_high)
        else:
            # Divide the action space into 10 discrete levels (0%, 10%, 20%, ..., 90% sparsity)
            pruning_action = pruning_action / 10
        msglogger.debug("\tAgent clipped pruning_action={}".format(pruning_action))

        if self.amc_cfg.action_constrain_fn is not None:
            pruning_action = self.amc_cfg.action_constrain_fn(self, pruning_action=pruning_action)
            msglogger.debug("Constrained pruning_action={}".format(pruning_action))

        # Calculate the final compression rate
        total_macs_before, _ = self.net_wrapper.get_resources_requirements()
        layer_macs = self.net_wrapper.layer_macs(self.current_layer())
        msglogger.debug("\tlayer_macs={:.2f}".format(layer_macs / self.original_model_macs))
        msglogger.debug("\tremoved_macs={:.2f}".format(self.removed_macs_pct))
        msglogger.debug("\trest_macs={:.2f}".format(self.rest_macs()))
        msglogger.debug("\tcurrent_layer_id = %d" % self.current_layer_id)
        self.current_state_id += 1
        if pruning_action > 0:
            pruning_action = self.net_wrapper.remove_structures(self.current_layer_id,
                                                                fraction_to_prune=pruning_action,
                                                                prune_what=self.amc_cfg.pruning_pattern,
                                                                prune_how=self.amc_cfg.pruning_method,
                                                                group_size=self.amc_cfg.group_size,
                                                                apply_thinning=self.episode_is_done(),
                                                                ranking_noise=self.amc_cfg.ranking_noise)
                                                                #random_state=self.random_state)
        else:
            pruning_action = 0

        self.action_history.append(pruning_action)
        total_macs_after_act, total_nnz_after_act = self.net_wrapper.get_resources_requirements()
        layer_macs_after_action = self.net_wrapper.layer_macs(self.current_layer())

        # Update the various counters after taking the step
        self._removed_macs += (total_macs_before - total_macs_after_act)

        msglogger.debug("\tactual_action={}".format(pruning_action))
        msglogger.debug("\tlayer_macs={} layer_macs_after_action={} removed now={}".format(
            layer_macs, layer_macs_after_action, (layer_macs - layer_macs_after_action)))
        msglogger.debug("\tself._removed_macs={}".format(self._removed_macs))
        assert math.isclose(layer_macs_after_action / layer_macs, 1 - pruning_action)

        stats = ('Performance/Validation/', OrderedDict([('requested_action', pruning_action)]))
        distiller.log_training_progress(stats, None, self.episode,
                                        steps_completed=self.current_state_id,
                                        total_steps=self.net_wrapper.num_pruned_layers(),
                                        log_freq=1, loggers=[self.tflogger])

        if self.episode_is_done():
            msglogger.info("Episode is ending")
            observation = self.get_final_obs()
            reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act)
            normalized_macs = total_macs_after_act / self.original_model_macs * 100
            normalized_nnz = total_nnz_after_act / self.original_model_size * 100
            self.finalize_episode(top1, reward, total_macs_after_act, normalized_macs,
                                  normalized_nnz, self.action_history, self.agent_action_history)
            self.episode += 1
        else:
            self.current_layer_id = self.net_wrapper.model_metadata.pruned_idxs[self.current_state_id]
            if self.amc_cfg.ft_frequency is not None and self.current_state_id % self.amc_cfg.ft_frequency == 0:
                self.net_wrapper.train(1, self.episode)
            observation = self.get_obs()
            if self.amc_cfg.reward_frequency is not None and self.current_state_id % self.amc_cfg.reward_frequency == 0:
                reward, top1 = self.compute_reward(total_macs_after_act, total_nnz_after_act, log_stats=False)
            else:
                reward = 0
        self.prev_action = pruning_action
        if self.episode_is_done():
            info = {"accuracy": top1, "compress_ratio": normalized_macs}
            msglogger.info(self.removed_macs_pct)
            if self.amc_cfg.protocol == "mac-constrained":
                # Sanity check (special case only for "mac-constrained")
                #assert self.removed_macs_pct >= 1 - self.amc_cfg.target_density - 0.01
                pass
        else:
            info = {}
        return observation, reward, self.episode_is_done(), info

    def get_obs(self):
        """Produce a state embedding (i.e. an observation)"""
        current_layer_macs = self.net_wrapper.layer_net_macs(self.current_layer())
        current_layer_macs_pct = current_layer_macs / self.original_model_macs
        current_layer = self.current_layer()
        conv_module = distiller.model_find_module(self.model, current_layer.name)
        obs = self.model_representation[self.current_state_id, :]
        obs[-1] = self.prev_action
        obs[-2] = self.rest_macs()
        obs[-3] = self.removed_macs_pct
        msglogger.debug("obs={}".format(Observation._make(obs)))
        # Sanity check
        assert (self.removed_macs_pct + current_layer_macs_pct + self.rest_macs()) <= 1
        return obs

    def get_final_obs(self):
        """Return the final state embedding (observation).

        The final state is reached after we traverse all of the Convolution layers.
        """
        obs = self.model_representation[-1, :]
        msglogger.debug("obs={}".format(Observation._make(obs)))
        return obs

    def get_model_representation(self):
        """Initialize an embedding representation of the entire model.

        At runtime, a specific row in the embedding matrix is chosen (depending on
        the current state) and the dynamic fields in the resulting state-embedding
        vector are updated.
        """
        num_states = self.net_wrapper.num_pruned_layers()
        network_obs = np.empty(shape=(num_states, ObservationLen))
        for state_id, layer_id in enumerate(self.net_wrapper.model_metadata.pruned_idxs):
            layer = self.net_wrapper.get_layer(layer_id)
            layer_macs = self.net_wrapper.layer_macs(layer)
            conv_module = distiller.model_find_module(self.model, layer.name)
            obs = [state_id,
                   conv_module.out_channels,
                   conv_module.in_channels,
                   layer.ifm_h,
                   layer.ifm_w,
                   layer.stride[0],
                   layer.k,
                   distiller.volume(conv_module.weight),
                   layer_macs,
                   0, 0, 0]
            network_obs[state_id, :] = np.array(obs)

        # Feature normalization
        for feature in range(ObservationLen):
            feature_vec = network_obs[:, feature]
            fmin = min(feature_vec)
            fmax = max(feature_vec)
            if fmax - fmin > 0:
                network_obs[:, feature] = (feature_vec - fmin) / (fmax - fmin)
        # msglogger.debug("model representation=\n{}".format(network_obs))
        return network_obs

    def rest_macs_raw(self):
        """Return the number of remaining MACs in the layers following the current layer"""
        rest, prunable_rest = 0, 0
        prunable_layers, rest_layers, layers_to_ignore = list(), list(), list()

        # Create a list of the IDs of the layers that are dependent on the current_layer.
        # We want to ignore these layers when we compute prunable_layers (and prunable_rest).
        for dependent_mod in self.current_layer().dependencies:
            layers_to_ignore.append(self.net_wrapper.name2layer(dependent_mod).id)

        for layer_id in range(self.current_layer_id + 1, self.net_wrapper.model_metadata.num_layers()):
            layer_macs = self.net_wrapper.layer_net_macs(self.net_wrapper.get_layer(layer_id))
            if self.net_wrapper.model_metadata.is_reducible(layer_id):
                if layer_id not in layers_to_ignore:
                    prunable_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
                    prunable_rest += layer_macs
            else:
                rest_layers.append((layer_id, self.net_wrapper.get_layer(layer_id).name, layer_macs))
                rest += layer_macs

        msglogger.debug("prunable_layers={} rest_layers={}".format(prunable_layers, rest_layers))
        msglogger.debug("layer_id=%d, prunable_rest=%.3f rest=%.3f" %
                        (self.current_layer_id, prunable_rest, rest))
        return prunable_rest, rest

    def rest_macs(self):
        return sum(self.rest_macs_raw()) / self.original_model_macs

    def is_macs_constraint_achieved(self, compressed_model_total_macs):
        current_density = compressed_model_total_macs / self.original_model_macs
        return self.amc_cfg.target_density >= current_density

    def compute_reward(self, total_macs, total_nnz, log_stats=True):
        """Compute the reward.

        We use the validation dataset (the size of the validation dataset is
        configured when the data-loader is instantiated).
        """
        distiller.log_weights_sparsity(self.model, -1, loggers=[self.pylogger])
        compression = distiller.model_numel(self.model, param_dims=[4]) / self.original_model_size

        # Fine-tune (this is a nop if self.amc_cfg.num_ft_epochs==0)
        accuracies = self.net_wrapper.train(self.amc_cfg.num_ft_epochs, self.episode)
        self.ft_stats_logger.add_record([self.episode, accuracies])

        top1, top5, vloss = self.net_wrapper.validate()
        reward = self.amc_cfg.reward_fn(self, top1, top5, vloss, total_macs)

        if log_stats:
            macs_normalized = total_macs / self.original_model_macs
            msglogger.info("Total parameters left: %.2f%%" % (compression * 100))
            msglogger.info("Total compute left: %.2f%%" % (total_macs / self.original_model_macs * 100))
            stats = ('Performance/EpisodeEnd/',
                     OrderedDict([('Loss', vloss),
                                  ('Top1', top1),
                                  ('Top5', top5),
                                  ('reward', reward),
                                  ('total_macs', int(total_macs)),
                                  ('macs_normalized', macs_normalized * 100),
                                  ('log(total_macs)', math.log(total_macs)),
                                  ('total_nnz', int(total_nnz))]))
            distiller.log_training_progress(stats, None, self.episode,
                                            steps_completed=0, total_steps=1, log_freq=1,
                                            loggers=[self.tflogger, self.pylogger])
        return reward, top1

    def finalize_episode(self, top1, reward, total_macs, normalized_macs,
                         normalized_nnz, action_history, agent_action_history):
        """Write the details of one network to the logger and create a checkpoint file"""
        if reward > self.best_reward:
            self.best_reward = reward
            ckpt_name = self.save_checkpoint(is_best=True)
            msglogger.info("Best reward={} episode={} top1={}".format(reward, self.episode, top1))
        else:
            ckpt_name = self.save_checkpoint(is_best=False)

        import json
        performance = self.net_wrapper.performance_summary()
        fields = [self.episode, top1, reward, total_macs, normalized_macs, normalized_nnz,
                  ckpt_name, json.dumps(action_history), json.dumps(agent_action_history),
                  json.dumps(performance)]
        self.stats_logger.add_record(fields)

    def save_checkpoint(self, is_best=False):
        """Save the learned-model checkpoint"""
        episode = str(self.episode).zfill(3)
        if is_best:
            fname = "BEST_adc_episode_{}".format(episode)
        else:
            fname = "adc_episode_{}".format(episode)
        if is_best or self.amc_cfg.save_chkpts:
            # Always save the best episodes, and depending on amc_cfg.save_chkpts
            # save all other episodes
            scheduler = self.net_wrapper.create_scheduler()
            self.services.save_checkpoint_fn(epoch=0, model=self.model,
                                             scheduler=scheduler, name=fname)
            del scheduler
        return fname
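

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): one way an agent-side
# loop might drive DistillerWrapperEnvironment through a single episode via
# the standard gym.Env reset/step protocol implemented above.  `env` is
# assumed to be a fully-constructed DistillerWrapperEnvironment, and
# `choose_action` is a hypothetical stand-in for an RL agent's policy.
# ---------------------------------------------------------------------------
def _example_episode_rollout(env, choose_action=None):
    """Prune each layer in turn until the environment signals episode end."""
    if choose_action is None:
        def choose_action(obs):
            # Hypothetical default policy: sample a sparsity level uniformly
            # from the environment's configured action range.
            return np.random.uniform(env.action_low, env.action_high, size=(1,))

    observation = env.reset()
    done = False
    info = {}
    while not done:
        env.render()
        action = choose_action(observation)  # desired sparsity for the current layer
        observation, reward, done, info = env.step(action)
    # When the episode ends, `info` carries the final accuracy and the
    # normalized compression ratio (see step() above).
    return info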
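

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not necessarily the reward function the
# application registers): compute_reward() above invokes
# amc_cfg.reward_fn(env, top1, top5, vloss, total_macs), so any reward
# function must match that signature.  The variant below follows the
# resource-aware reward popularized by the AMC paper, R = -error * log(MACs).
# ---------------------------------------------------------------------------
def example_amc_reward_fn(env, top1, top5, vloss, total_macs):
    error = (100.0 - top1) / 100.0
    return -error * math.log(total_macs)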