Example #1
0
 def get_named_depthwise_bn(self, prefix=None):
     """Return an ordered `{name: module}` mapping of the BatchNorm layers
     that directly follow the depthwise convolution in each op of `self.ops`.

     Names follow the `ops.<op_index>.<child_index>.1` convention, optionally
     prefixed via `add_prefix`.
     """
     # The depthwise ConvBNReLU sits at child index 1 when an expand layer
     # precedes it, and at index 0 otherwise; `self.expand` is fixed per
     # module, so the index is the same for every op.
     idx_op = 1 if self.expand else 0
     named_bns = collections.OrderedDict()
     for op_index, op in enumerate(self.ops):
         conv_bn_relu = list(op.children())[idx_op]
         assert isinstance(conv_bn_relu, ConvBNReLU)
         _, bn, _ = list(conv_bn_relu.children())
         assert isinstance(bn, nn.BatchNorm2d)
         key = add_prefix('ops.{}.{}.1'.format(op_index, idx_op), prefix)
         named_bns[key] = bn
     return named_bns
Example #2
0
 def get_named_depthwise_bn(self, prefix=None):
     """Get `{name: module}` pairs of BN after depthwise convolution."""
     res = collections.OrderedDict()
     # Position of the depthwise ConvBNReLU inside each op:
     #   expand=True  -> index 1, e.g. InvertedResidual(16, 24, channels=[96, 96, 96], kernel_sizes=[3, 5, 7], expand=True, stride=2)
     #   expand=False -> index 0, e.g. InvertedResidual(16, 16, channels=[16], kernel_sizes=[3], expand=False, stride=1)
     for i, op in enumerate(self.ops):
         if self.expand:
             idx_op = 1
         else:
             idx_op = 0
         depthwise = list(op.children())[idx_op]
         assert isinstance(depthwise, ConvBNReLU)
         # ConvBNReLU is expected to hold exactly (conv, bn, activation).
         _, bn, _ = list(depthwise.children())
         assert isinstance(bn, nn.BatchNorm2d)
         res[add_prefix('ops.{}.{}.1'.format(i, idx_op), prefix)] = bn
     return res
Example #3
0
 def _build_name(prefix):
     """Prepend *prefix* to ``attr`` via ``add_prefix``.

     NOTE(review): ``attr`` is a free variable, not a parameter — this
     helper only makes sense nested inside a scope that binds ``attr``;
     verify against the enclosing definition.
     """
     return add_prefix(attr, prefix)
Example #4
0
def copmress_inverted_residual_channels(m,
                                        masks,
                                        ema=None,
                                        optimizer=None,
                                        prune_info=None,
                                        prefix=None,
                                        verbose=False):
    """Prune the hidden channels of an inverted-residual module in place.

    One mask per parallel op in ``m.ops`` selects which hidden channels to
    keep (``mask.sum()`` channels survive); ops whose mask keeps zero
    channels are dropped entirely.  ``m.ops``/``m.pw_bn`` are rebuilt via
    ``m._build`` and the matching per-variable state held by ``optimizer``,
    ``ema`` and ``prune_info`` is remapped to the new, smaller variables.

    NOTE(review): ``m`` is presumably an ``InvertedResidual``-style module
    exposing ``kernel_sizes``, ``channels``, ``expand``, ``ops``, ``pw_bn``,
    ``active_fn`` and ``_build`` — confirm against the class definition.
    Note the function name contains a typo ("copmress"); it is kept because
    callers elsewhere may reference it.

    Args:
        m: module to compress; mutated in place (``channels``,
            ``kernel_sizes``, ``ops``, ``pw_bn`` are all replaced).
        masks: one mask tensor per op, same length as ``m.kernel_sizes``.
        ema: optional EMA tracker whose shadow variables are compressed
            alongside the model.
        optimizer: optional optimizer whose per-variable state is
            compressed alongside the model.
        prune_info: optional bookkeeping object for prunable BN variables.
        verbose: forwarded to the ``compress_mask``/``compress_drop``
            helpers.
        prefix: optional name prefix applied to every variable name via
            ``add_prefix``.
    """
    # Remap state for surviving variables, then run the mask hook that
    # copies old values into the new (smaller) variables.
    def update(infos):
        for info in infos:
            if optimizer is not None and info['type'] != 'buffer':
                optimizer.compress_mask(info, verbose=verbose)
            # `num_batches_tracked` is an integer counter with no EMA shadow.
            if ema is not None and 'num_batches_tracked' not in info[
                    'var_old_name']:
                ema.compress_mask(info, verbose=verbose)
            # Only BN *variables* (not buffers) are tracked by prune_info.
            if prune_info is not None and issubclass(
                    info['module_class'],
                    nn.BatchNorm2d) and info['type'] == 'variable':
                if prune_info.compress_check_exist(info):
                    prune_info.compress_mask(info, verbose=verbose)
            info['mask_hook'](info['var_new'], info['var_old'], info['mask'])
            if 'post_hook' in info:
                # FIXME(meijieru): bn adjust
                warnings.warn('Do not adjust bn mean!!!')
                # info['post_hook'](info)

    # Drop all tracked state for variables of entirely-removed ops.
    # Mirrors `update` but calls compress_drop instead of compress_mask.
    def clean(infos):
        for info in infos:
            if optimizer is not None and info['type'] != 'buffer':
                optimizer.compress_drop(info, verbose=verbose)
            if ema is not None and 'num_batches_tracked' not in info[
                    'var_old_name']:
                ema.compress_drop(info, verbose=verbose)
            if prune_info is not None and issubclass(
                    info['module_class'],
                    nn.BatchNorm2d) and info['type'] == 'variable':
                if prune_info.compress_check_exist(info):
                    prune_info.compress_drop(info, verbose=verbose)

    assert len(m.kernel_sizes) == len(masks)
    # Number of hidden channels kept per op; an op with 0 remaining
    # channels is removed from the module entirely.
    hidden_dims = [mask.detach().sum().item() for mask in masks]
    indices = torch.arange(len(m.ops))
    keeps = [num_remain > 0 for num_remain in hidden_dims]
    # Filter out fully-pruned ops from the module's configuration.
    m.channels, m.kernel_sizes = [
        list(itertools.compress(x, keeps))
        for x in [hidden_dims, m.kernel_sizes]
    ]
    new_ops, new_pw_bn = m._build(m.channels, m.kernel_sizes, m.expand)
    new_indices = torch.arange(len(new_ops))
    # Child positions inside each op: with an expand layer at index 0, the
    # depthwise ConvBNReLU and projection conv shift by one.
    if m.expand:
        idx_depth = 1
        idx_proj = 2
    else:
        idx_depth = 0
        idx_proj = 1

    # Align the (shorter) new op list with the old one: dropped positions
    # become None so old and new ops can be zipped pairwise below.
    new_ops_padded = _scatter_by_bool(new_ops, keeps)
    new_indices_padded = _scatter_by_bool(new_indices, keeps)

    # update ema, optimizer, module
    pending_clean_infos = []   # infos for ops being removed
    pending_adjust_infos = []  # infos for ops being shrunk
    adjust_infos = []          # per-op data for the shared pw_bn adjustment
    for new_op, new_indice, old_op, old_indice, mask in zip(
            new_ops_padded, new_indices_padded, m.ops, indices, masks):
        old_op_children = list(old_op.children())
        if new_op is None:  # drop old
            new_op_children = [None for _ in old_op_children]
            pending_infos = pending_clean_infos
        else:
            new_op_children = list(new_op.children())
            pending_infos = pending_adjust_infos
        # Expand (pointwise) ConvBNReLU, present only when m.expand.
        if m.expand:
            expand_infos = compress_conv_bn_relu(
                new_op_children[0], old_op_children[0], mask,
                add_prefix('ops.{}.0'.format(new_indice), prefix),
                add_prefix('ops.{}.0'.format(old_indice), prefix))
            pending_infos.append(expand_infos)
        # Depthwise ConvBNReLU: masked along its output channels.
        depth_infos = compress_conv_bn_relu(
            new_op_children[idx_depth], old_op_children[idx_depth], mask,
            add_prefix('ops.{}.{}'.format(new_indice, idx_depth), prefix),
            add_prefix('ops.{}.{}'.format(old_indice, idx_depth), prefix))
        pending_infos.append(depth_infos)
        # Projection conv: masked along its *input* channels (dim=1).
        proj_infos = compress_conv(
            new_op_children[idx_proj],
            old_op_children[idx_proj],
            mask,
            dim=1,
            prefix_new=add_prefix('ops.{}.{}'.format(new_indice, idx_proj),
                                  prefix),
            prefix_old=add_prefix('ops.{}.{}'.format(old_indice, idx_proj),
                                  prefix))
        pending_infos.append(proj_infos)

        # Collect the depthwise BN bias and the following projection weight
        # so the shared pw_bn can compensate for the removed channels.
        adjust_info = {'active_fn': m.active_fn}
        adjust_info['bias'] = _find_only_one(
            lambda info: issubclass(info['module_class'], nn.BatchNorm2d) and
            'bias' in info['var_old_name'], depth_infos)
        adjust_info['following_conv_weight'] = _find_only_one(
            lambda info: issubclass(info['module_class'], nn.Conv2d) and
            'weight' in info['var_old_name'], proj_infos)
        adjust_infos.append(adjust_info)
    # pw_bn is shared across ops, so its name is the same old and new.
    prefix_pw_bn = add_prefix('pw_bn', prefix)
    pw_bn_infos = adjust_bn(new_pw_bn,
                            m.pw_bn,
                            adjust_infos,
                            prefix_new=prefix_pw_bn,
                            prefix_old=prefix_pw_bn)

    if ema is not None:
        ema.compress_start()
    if prune_info is not None:
        prune_info.compress_start()
    update(pw_bn_infos)  # NOTE: must do before following for ema
    for infos in pending_adjust_infos:
        update(infos)
    for infos in pending_clean_infos:  # remove non-use last
        clean(infos)

    # Swap in the rebuilt submodules only after all state is remapped.
    del m.ops
    del m.pw_bn
    m.ops, m.pw_bn = new_ops, new_pw_bn