import collections
import itertools
import warnings

import torch
from torch import nn


def get_named_depthwise_bn(self, prefix=None):
    """Get `{name: module}` pairs of BN after depthwise convolution."""
    res = collections.OrderedDict()
    for i, op in enumerate(self.ops):
        children = list(op.children())
        if self.expand:
            # For Atom blocks, e.g. InvertedResidual(16, 24,
            # channels=[96, 96, 96], kernel_sizes=[3, 5, 7], expand=True,
            # stride=2): the depthwise ConvBNReLU follows the expansion layer.
            idx_op = 1
        else:
            # For the first block, e.g. InvertedResidual(16, 16, channels=[16],
            # kernel_sizes=[3], expand=False, stride=1): no expansion layer.
            idx_op = 0
        conv_bn_relu = children[idx_op]
        assert isinstance(conv_bn_relu, ConvBNReLU)
        _, bn, _ = list(conv_bn_relu.children())
        assert isinstance(bn, nn.BatchNorm2d)
        name = add_prefix('ops.{}.{}.1'.format(i, idx_op), prefix)
        res[name] = bn
    return res
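# Usage sketch (illustrative, not part of this repo's API): gather the
# depthwise BNs of every block so their scaling factors can serve as
# per-channel importance signals for pruning. `model` is hypothetical.
#
#     named_bns = collections.OrderedDict()
#     for name, module in model.named_modules():
#         if isinstance(module, InvertedResidual):
#             named_bns.update(module.get_named_depthwise_bn(prefix=name))
#     for bn_name, bn in named_bns.items():
#         print(bn_name, bn.weight.detach().abs().mean().item())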
def _build_name(prefix):
    # NOTE: nested helper; `attr` is expected to be captured from the
    # enclosing scope.
    return add_prefix(attr, prefix)
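# `add_prefix` is used throughout but not defined in this section. A minimal
# sketch of the assumed semantics (join with a dot, pass the name through
# when the prefix is None):
#
#     def add_prefix(name, prefix=None):
#         if prefix is None:
#             return name
#         return '{}.{}'.format(prefix, name)
#
# e.g. add_prefix('ops.0.1', 'features.3') -> 'features.3.ops.0.1'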
def compress_inverted_residual_channels(m,
                                        masks,
                                        ema=None,
                                        optimizer=None,
                                        prune_info=None,
                                        prefix=None,
                                        verbose=False):
    """Compress an InvertedResidual `m` in place given per-branch channel masks.

    Branches whose masks keep zero channels are dropped entirely; optimizer,
    EMA and pruning bookkeeping are updated to stay consistent.
    """

    def update(infos):
        # Apply masks to optimizer/EMA/prune-info state and the variables.
        for info in infos:
            if optimizer is not None and info['type'] != 'buffer':
                optimizer.compress_mask(info, verbose=verbose)
            if ema is not None and 'num_batches_tracked' not in info[
                    'var_old_name']:
                ema.compress_mask(info, verbose=verbose)
            if prune_info is not None and issubclass(
                    info['module_class'],
                    nn.BatchNorm2d) and info['type'] == 'variable':
                if prune_info.compress_check_exist(info):
                    prune_info.compress_mask(info, verbose=verbose)
            info['mask_hook'](info['var_new'], info['var_old'], info['mask'])
            if 'post_hook' in info:
                # FIXME(meijieru): bn adjust
                warnings.warn('Do not adjust bn mean!!!')
                # info['post_hook'](info)

    def clean(infos):
        # Drop optimizer/EMA/prune-info state for removed branches.
        for info in infos:
            if optimizer is not None and info['type'] != 'buffer':
                optimizer.compress_drop(info, verbose=verbose)
            if ema is not None and 'num_batches_tracked' not in info[
                    'var_old_name']:
                ema.compress_drop(info, verbose=verbose)
            if prune_info is not None and issubclass(
                    info['module_class'],
                    nn.BatchNorm2d) and info['type'] == 'variable':
                if prune_info.compress_check_exist(info):
                    prune_info.compress_drop(info, verbose=verbose)

    assert len(m.kernel_sizes) == len(masks)
    hidden_dims = [mask.detach().sum().item() for mask in masks]
    indices = torch.arange(len(m.ops))
    keeps = [num_remain > 0 for num_remain in hidden_dims]
    m.channels, m.kernel_sizes = [
        list(itertools.compress(x, keeps))
        for x in [hidden_dims, m.kernel_sizes]
    ]
    new_ops, new_pw_bn = m._build(m.channels, m.kernel_sizes, m.expand)
    new_indices = torch.arange(len(new_ops))
    if m.expand:
        idx_depth = 1
        idx_proj = 2
    else:
        idx_depth = 0
        idx_proj = 1
    new_ops_padded = _scatter_by_bool(new_ops, keeps)
    new_indices_padded = _scatter_by_bool(new_indices, keeps)

    # Update EMA, optimizer and module.
    pending_clean_infos = []
    pending_adjust_infos = []
    adjust_infos = []
    for new_op, new_indice, old_op, old_indice, mask in zip(
            new_ops_padded, new_indices_padded, m.ops, indices, masks):
        old_op_children = list(old_op.children())
        if new_op is None:  # drop old branch
            new_op_children = [None for _ in old_op_children]
            pending_infos = pending_clean_infos
        else:
            new_op_children = list(new_op.children())
            pending_infos = pending_adjust_infos
        if m.expand:
            expand_infos = compress_conv_bn_relu(
                new_op_children[0], old_op_children[0], mask,
                add_prefix('ops.{}.0'.format(new_indice), prefix),
                add_prefix('ops.{}.0'.format(old_indice), prefix))
            pending_infos.append(expand_infos)
        depth_infos = compress_conv_bn_relu(
            new_op_children[idx_depth], old_op_children[idx_depth], mask,
            add_prefix('ops.{}.{}'.format(new_indice, idx_depth), prefix),
            add_prefix('ops.{}.{}'.format(old_indice, idx_depth), prefix))
        pending_infos.append(depth_infos)
        proj_infos = compress_conv(
            new_op_children[idx_proj],
            old_op_children[idx_proj],
            mask,
            dim=1,
            prefix_new=add_prefix('ops.{}.{}'.format(new_indice, idx_proj),
                                  prefix),
            prefix_old=add_prefix('ops.{}.{}'.format(old_indice, idx_proj),
                                  prefix))
        pending_infos.append(proj_infos)

        adjust_info = {'active_fn': m.active_fn}
        adjust_info['bias'] = _find_only_one(
            lambda info: issubclass(info['module_class'], nn.BatchNorm2d)
            and 'bias' in info['var_old_name'], depth_infos)
        adjust_info['following_conv_weight'] = _find_only_one(
            lambda info: issubclass(info['module_class'], nn.Conv2d)
            and 'weight' in info['var_old_name'], proj_infos)
        adjust_infos.append(adjust_info)

    prefix_pw_bn = add_prefix('pw_bn', prefix)
    pw_bn_infos = adjust_bn(new_pw_bn,
                            m.pw_bn,
                            adjust_infos,
                            prefix_new=prefix_pw_bn,
                            prefix_old=prefix_pw_bn)

    if ema is not None:
        ema.compress_start()
    if prune_info is not None:
        prune_info.compress_start()
    update(pw_bn_infos)  # NOTE: must run before the loops below for EMA.
    for infos in pending_adjust_infos:
        update(infos)
    for infos in pending_clean_infos:  # Remove unused branches last.
        clean(infos)

    del m.ops
    del m.pw_bn
    m.ops, m.pw_bn = new_ops, new_pw_bn
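# `_scatter_by_bool` is referenced above but not defined in this section. A
# minimal sketch consistent with its usage here (scatter kept items back to
# the positions marked True, leaving None where a branch was dropped):
#
#     def _scatter_by_bool(items, keeps):
#         it = iter(items)
#         return [next(it) if keep else None for keep in keeps]
#
# End-to-end usage sketch (hypothetical names and threshold): derive one
# boolean mask per branch from the depthwise BN scaling factors, then
# compress the block in place.
#
#     masks = [bn.weight.detach().abs() > 1e-3
#              for bn in block.get_named_depthwise_bn().values()]
#     compress_inverted_residual_channels(block, masks, ema=ema,
#                                         optimizer=optimizer,
#                                         prefix='features.3')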