def update_channels_to_winnow(self, list_of_zero_in_channels: List[int], list_of_zero_out_channels: List[int]):
    """
    Sets the parameters associated with Mask Propagation.

    Only Conv and Linear ops may be winnowed. Both kinds are handled the same way: the op's total
    in/out channel counts are paired with the channels to winnow and forwarded to
    _update_conv_linear_channels_to_winnow().

    :param list_of_zero_in_channels: List of in channels to winnow
    :param list_of_zero_out_channels: List of out channels to winnow
    :raises ValueError: if the op type is neither a Conv nor a Linear op for this model API
    """
    if self._op_type not in get_conv_ops_for_api(self._model_api) and \
            self._op_type not in get_linear_ops_for_api(self._model_api):
        raise ValueError(" Module type %s is not allowed to be winnowed" % self._op_type)

    # Conv and Linear previously had separate but identical branches; a single path suffices.
    in_channels_total_and_winnow = (self._num_in_channels, list_of_zero_in_channels)
    out_channels_total_and_winnow = (self._num_out_channels, list_of_zero_out_channels)

    self._update_conv_linear_channels_to_winnow(
        in_channels_total_and_winnow, out_channels_total_and_winnow)
def reduce_modules(self, list_of_ops_to_reduce: List):
    """
    Reduce the modules associated with the given Ops.

    Conv Ops (PyTorch conv types) are reduced with _reduce_conv_module() and BatchNorm Ops
    ('BatchNorm2d' / 'batch_norm') with _reduce_batchnorm_module(); any other Op is skipped with a
    debug log message.

    :param list_of_ops_to_reduce: list of Ops whose associated modules need to be reduced.
    :return: dictionary mapping dotted names of reduced Ops to the reduced modules themselves
    """
    conv_op_types = get_conv_ops_for_api(ModelApi.pytorch)
    reduced_by_name = {}

    for op in list_of_ops_to_reduce:
        if op.type in conv_op_types:
            reducer = self._reduce_conv_module
        elif op.type in ['BatchNorm2d', 'batch_norm']:
            reducer = self._reduce_batchnorm_module
        else:
            logger.debug("reduce_modules(): skipping: %s", op.dotted_name)
            continue
        reduced_by_name[op.dotted_name] = reducer(op)

    return reduced_by_name
def _adjust_downstream_op_masks(self, downstream_op: Op, modified_mask: List[int], model_api: ModelApi):
    """
    Starting with the downstream_op, adjust the input and output masks for the Ops until a Conv Op is reached.

    Propagation also stops at an Op whose internal connectivity is StopInternalConnectivity, and at an
    Op whose output has no consumers.

    :param downstream_op: the starting downstream op
    :param modified_mask: the mask to be set for the downstream Ops
    :param model_api: either tensorflow or pytorch
    """
    if downstream_op.type in get_conv_ops_for_api(model_api):
        # Reached a Conv Op: its masks are left untouched and propagation ends here.
        return

    downstream_op_mask = self._op_to_mask_dict[downstream_op]
    connectivity = downstream_op_mask.internal_connectivity

    if isinstance(connectivity, SplitInternalConnectivity):
        # Downstream Op has single input and multiple outputs.
        downstream_op_mask.set_input_channel_mask(0, modified_mask)
        num_out_masks = len(downstream_op_mask.output_channel_masks)
        for index in range(num_out_masks):
            downstream_op_mask.set_output_channel_mask(index, modified_mask)
            # Output mask `index` is assumed to correspond to consumer `index`.
            self._adjust_downstream_op_masks(downstream_op.output.consumers[index], modified_mask, model_api)
    elif not isinstance(connectivity, StopInternalConnectivity):
        # Downstream Op has single input and single output.
        downstream_op_mask.set_input_channel_mask(0, modified_mask)
        downstream_op_mask.set_output_channel_mask(0, modified_mask)
        logger.debug("Masks adjusted for: %s", downstream_op.dotted_name)
        output = downstream_op.output
        # Guard against an output with no consumers, which previously raised IndexError on consumers[0].
        if output and output.consumers:
            self._adjust_downstream_op_masks(output.consumers[0], modified_mask, model_api)
    # Otherwise the Op has Stop connectivity: stop propagating downstream.
def _set_default_input_output_masks(self, in_channels: int, out_channels: int):
    """
    Based on the Op type, sets default input and output channel masks.

    The connectivity type returned by OpConnectivity.get_op_connectivity() for this Op's model API and
    type selects how the defaults are built. The result ends up in self._internal_connectivity, either
    assigned directly here or through one of the _set_default_masks_for_* helpers.

    :param in_channels: The number of input channels
    :param out_channels: The number of output channels
    :raises NotImplementedError: if the connectivity type is not one of the ConnectivityType values handled here
    """
    op_connectivity = OpConnectivity.get_op_connectivity(
        self._model_api, self._op_type)

    if op_connectivity == ConnectivityType.null:
        if self._op_type in get_conv_ops_for_api(self._model_api) or \
                self._op_type in get_linear_ops_for_api(self._model_api):
            self._set_default_masks_for_conv_and_linear()
        else:
            self._set_default_masks_for_null_and_stop_connectivity_ops(
                in_channels, out_channels, is_null_connectivity=True)
    elif op_connectivity == ConnectivityType.direct:
        # Necessary to switch connectivity of padding to null when adjusting channel size since staying at direct
        # connectivity will cause input and output channel sizes to become equal
        if self._model_api == ModelApi.tensorflow and self._op_type in ["Pad", "PadV2", "MirrorPad"] and \
                in_channels != out_channels:
            self._set_default_masks_for_null_and_stop_connectivity_ops(
                in_channels, out_channels, is_null_connectivity=True)
        else:
            self._set_default_masks_for_direct_connectivity_ops(
                in_channels, out_channels)
    elif op_connectivity == ConnectivityType.add:
        in_masks_list, out_masks_list = self._create_masks_list_for_multi_input_single_output_ops(
            out_channels)
        self._internal_connectivity = AddInternalConnectivity(
            in_masks_list, out_masks_list)
    elif op_connectivity == ConnectivityType.concat:
        in_masks_list, out_masks_list = self._create_masks_list_for_multi_input_single_output_ops(
            out_channels)
        self._internal_connectivity = ConcatInternalConnectivity(
            in_masks_list, out_masks_list)
    elif op_connectivity == ConnectivityType.split:
        in_masks_list, out_masks_list = self._create_masks_list_for_single_input_multi_output_ops(
            in_channels)
        self._internal_connectivity = SplitInternalConnectivity(
            in_masks_list, out_masks_list)
    elif op_connectivity == ConnectivityType.skip:
        # Skip connectivity is created with empty input and output mask lists.
        in_masks_list = []
        out_masks_list = []
        self._internal_connectivity = SkipInternalConnectivity(
            in_masks_list, out_masks_list)
    elif op_connectivity == ConnectivityType.stop:
        self._set_default_masks_for_null_and_stop_connectivity_ops(
            in_channels, out_channels, is_null_connectivity=False)
    else:
        logger.error("Unsupported op_type %s, dotted %s, input_ops: %s",
                     self._op_type, self._dotted_name, self._op_input_ops)
        # Carry the context in the exception too, so it is not lost if the log is not visible.
        raise NotImplementedError(
            "Unsupported connectivity type %s for op_type %s (%s)"
            % (op_connectivity, self._op_type, self._dotted_name))
def _create_masks(self):
    """
    Create masks for Ops in the connected graph.

    For TensorFlow, masks are created only starting from Conv Ops (each Conv Op and all of its ancestors).
    For any other model API, every Op is used as a starting point.
    """
    # TODO: Only creating masks for ops that lead to conv ops was only tested for TF. See if the same can be
    # done for pytorch, where we traditionally created masks for all ops.
    conv_ancestry_only = self._model_api == ModelApi.tensorflow

    for candidate_op in self._ops.values():
        if conv_ancestry_only and candidate_op.type not in get_conv_ops_for_api(self._model_api):
            continue
        self._create_masks_for_op_and_all_ancestors(candidate_op)
def _prepend_downsample_layer_to_module(
        self, op: Operation, input_producer_op_out_mask: List[int]):
    """Creates a Sequential by prepending a Downsample layer to the module associated with the Op.
    Replaces the module with the Sequential at the module's parent.

    If op is not itself a Conv Op, the module used is that of the Conv Op returned by
    get_next_conv_op_for_op_with_single_consumer(op).

    :param op: The Op to which a Downsample layer is prepended.
    :param input_producer_op_out_mask: List of channels for the input op; 0 to winnow and 1 to keep channel
    :return: The torch.nn.Sequential (DownsampleLayer followed by the original module) that replaced the module
        at its parent
    """
    # Logging arguments are evaluated even when DEBUG is disabled, so do not index consumers blindly.
    consumers = op.output.consumers if op.output else []
    next_op = consumers[0] if consumers else None
    logger.debug(
        "Prepend Downsample: Op dotted name: %s, Op type: %s next down module dotted name: %s, type: %s",
        op.dotted_name, op.type,
        next_op.dotted_name if next_op is not None else None,
        next_op.type if next_op is not None else None)

    if op.type in get_conv_ops_for_api(ModelApi.pytorch):
        conv_op = op
    else:
        conv_op = get_next_conv_op_for_op_with_single_consumer(op)

    module = conv_op.get_module()
    parent_module_ref, var_name = self._parent_module_ref[module]
    op_mask = self._op_to_mask_dict[op]
    op_in_masks = op_mask.input_channel_masks
    keep_indices = get_indices_among_ones_of_overlapping_ones(
        input_producer_op_out_mask, op_in_masks[0])
    # Explicit integer dtype: torch.tensor([]) would default to float32, which cannot serve as an index tensor.
    keep_indices_tensor = torch.tensor(keep_indices, dtype=torch.long)  # pylint: disable=not-callable
    if self._using_cuda:
        keep_indices_tensor = keep_indices_tensor.cuda()

    down_sample = DownsampleLayer(keep_indices_tensor)
    # Create a sequential of the Downsample layer and the module
    seq = torch.nn.Sequential(down_sample, module)
    # Set the Sequential as the child for parent module.
    setattr(parent_module_ref, var_name, seq)
    logger.info("Prepended Downsample Layer to %s", op.dotted_name)
    return seq
def _set_default_masks_for_conv_and_linear(self):
    """
    Set the default input and output masks for Conv and Linear modules.

    One input mask is created per input Op and one output mask per consumer of the Op's output (none
    when there is no output). Conv Ops with groups == 1 and all Linear Ops get Null connectivity; Conv
    Ops with any other group value (depthwise convolution) get Direct connectivity.
    """
    is_conv = self._op_type in get_conv_ops_for_api(self._model_api)
    has_output = bool(self._op_output)

    num_input_masks = len(self._op_input_ops)
    num_output_masks = len(self._op_output.consumers) if has_output else 0
    # A Conv Op without an output uses an output mask length of 0; Linear Ops always use num_out_channels.
    if is_conv and not has_output:
        output_mask_length = 0
    else:
        output_mask_length = self._num_out_channels

    in_mask_length_list, out_mask_length_list = self._create_input_output_mask_and_length_tuples(
        num_input_masks, self._num_in_channels, num_output_masks, output_mask_length)

    # Group value of 1 represents a normal Conv2d Op which will have Null Connectivity.
    # Group value of anything else represents a depthwise convolution which will have Direct Connectivity
    if is_conv and self._groups != 1:
        connectivity_cls = DirectInternalConnectivity
    else:
        connectivity_cls = NullInternalConnectivity
    self._internal_connectivity = connectivity_cls(in_mask_length_list, out_mask_length_list)