# Note: the examples on this page assume the following imports; helper
# functions such as init_weight_proj, modify_submodules, set_module_param
# and make_optimizer are defined elsewhere in the same repository.
import os

import torch
import torch.nn as nn
from easydict import EasyDict as edict


def modify_network(net_current):
    args = net_current.args
    modules = []
    for module_cur in net_current.modules():
        if isinstance(module_cur, (Dense, Transition)):
            modules.append(module_cur)
    for module_cur in modules:

        # get initialization values for the ResBlock to be compressed
        if isinstance(module_cur, Dense):
            weight1 = module_cur.state_dict()['body.2.weight']
        elif isinstance(module_cur, Transition):
            weight1 = module_cur.state_dict()['2.weight']
        else:
            raise NotImplementedError('No need to compress layer ' +
                                      module_cur.__class__.__name__)

        if 'disturbance' in args.init_method:
            weight1, weight2 = init_weight_proj(weight1,
                                                init_method=args.init_method,
                                                d=0,
                                                s=0.05)
        else:
            weight1, weight2 = init_weight_proj(weight1,
                                                init_method=args.init_method)
        # modify submodules in the ResBlock
        modify_submodules(module_cur)
        # set ResBlock module params
        params = edict({'weight1': weight1, 'weight2': weight2})
        set_module_param(module_cur, params)
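
All of these modify_network variants delegate the actual factorization to init_weight_proj, which lives elsewhere in the repository and is not shown on this page. As a rough sketch of the idea only, assuming an SVD-style split (the function name and body below are hypothetical, not the project's implementation), a convolution weight can be factorized into a smaller weight plus a projection:

import torch

def init_weight_proj_sketch(weight):
    # hypothetical stand-in for init_weight_proj
    # weight: conv filter of shape (out_c, in_c, k1, k2)
    out_c, in_c, k1, k2 = weight.shape
    w2d = weight.reshape(out_c, -1)
    # thin SVD: w2d = u @ diag(s) @ vh
    u, s, vh = torch.linalg.svd(w2d, full_matrices=False)
    # split the singular values between the two factors so that
    # projection @ flattened(new_weight) reproduces the original weight
    projection = u * s.sqrt().unsqueeze(0)                       # (out_c, r)
    new_weight = (s.sqrt().unsqueeze(1) * vh).reshape(-1, in_c, k1, k2)
    return new_weight, projection

The 'disturbance' branches additionally pass d and s arguments, presumably a dimension and a small noise scale used to perturb the initialization before training.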
def modify_network(net_current):
    args = net_current.args
    modules = []
    for module_cur in net_current.modules():
        if isinstance(module_cur, wide_basic):
            modules.append(module_cur)
    for module_cur in modules:

        # get initialization values for the ResBlock to be compressed
        weight1 = module_cur.state_dict()['conv1.weight']
        bias1 = module_cur.state_dict()['conv1.bias']
        weight2 = module_cur.state_dict()['conv2.weight']
        bias2 = module_cur.state_dict()['conv2.bias']
        if 'disturbance' in args.init_method:
            weight1, projection1 = init_weight_proj(
                weight1, init_method=args.init_method, d=0, s=0.05)
            weight2, projection2 = init_weight_proj(
                weight2, init_method=args.init_method, d=1, s=0.05)
        else:
            weight1, projection1 = init_weight_proj(
                weight1, init_method=args.init_method)
            weight2, projection2 = init_weight_proj(
                weight2, init_method=args.init_method)
        # modify submodules in the ResBlock
        modify_submodules(module_cur)
        # set ResBlock module params
        params = edict({
            'weight1': weight1,
            'projection1': projection1,
            'bias1': bias1,
            'weight2': weight2,
            'projection2': projection2,
            'bias2': bias2
        })
        set_module_param(module_cur, params)
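
modify_submodules and set_module_param are likewise project helpers. A minimal sketch of what set_module_param plausibly does, registering the initialized tensors on the block so the optimizer can reach them (the name and behavior below are assumptions, not the repository's code):

import torch.nn as nn

def set_module_param_sketch(module, params):
    # hypothetical stand-in for set_module_param
    for name, value in params.items():
        if value is None:
            setattr(module, name, None)             # e.g. a missing bias
        elif name.startswith(('bn_mean', 'bn_var')):
            # running statistics are buffers, not trainable parameters
            module.register_buffer(name, value.clone())
        else:
            module.register_parameter(name, nn.Parameter(value.clone()))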
Example #3
def modify_network(net_current):
    args = net_current.args
    modules = []
    for module_cur in net_current.modules():
        if isinstance(module_cur, BasicBlock):
            modules.append(module_cur)
    # skip the first BasicBlock that deals with the input images
    for module_cur in modules[1:]:
        # get initialization values for the Block to be compressed
        weight1 = module_cur.state_dict()['0.weight']
        bias1 = module_cur.state_dict()['0.bias']
        if 'disturbance' in args.init_method:
            weight1, projection1 = init_weight_proj(
                weight1, init_method=args.init_method, d=0, s=0.05)
        else:
            weight1, projection1 = init_weight_proj(
                weight1, init_method=args.init_method)
        # modify submodules in the ResBlock
        modify_submodules(module_cur)
        # set ResBlock module params
        params = edict({
            'weight1': weight1,
            'projection1': projection1,
            'bias1': bias1
        })
        set_module_param(module_cur, params)
Example #4
def compress_resblock_together(net_current, net_parent, data, ckp, args):
    modules = []
    for module_cur in net_current.modules():
        if isinstance(module_cur, ResBlock):
            modules.append(module_cur)
    for module_cur in modules:
        global resblock_counter
        resblock_counter += 1
        module_cur.prune_procedure = args.prune_procedure  # choices: 'final', 'complete', 'undergoing'; module_cur.optimization is changed in modify_submodules

        # get initialization values for the ResBlock to be compressed
        weight1, bn_weight1, bn_bias1, bn_mean1, bn_var1, _, \
        weight2, bn_weight2, bn_bias2, bn_mean2, bn_var2, _ = [p.data for k, p in module_cur.state_dict().items()]
        # bn_weight1 = torch.ones(len(bn_weight1))
        # bn_bias1 = torch.zeros(len(bn_bias1))
        # bn_mean1 = torch.zeros(len(bn_mean1))
        # bn_var1 = torch.ones(len(bn_var1))
        # bn_weight2 = torch.ones(len(bn_weight2))
        # bn_bias2 = torch.zeros(len(bn_bias2))
        # bn_mean2 = torch.zeros(len(bn_mean2))
        # bn_var2 = torch.ones(len(bn_var2))
        if 'disturbance' in args.prune_init_method:
            weight1, projection1 = init_weight_proj(
                weight1, init_method=args.prune_init_method, d=0, s=0.05)
            weight2, projection2 = init_weight_proj(
                weight2, init_method=args.prune_init_method, d=1, s=0.05)
        else:
            weight1, projection1 = init_weight_proj(
                weight1, init_method=args.prune_init_method)
            weight2, projection2 = init_weight_proj(
                weight2, init_method=args.prune_init_method)
        # modify submodules in the ResBlock
        modify_submodules(module_cur)
        # set ResBlock module params
        params = edict({
            'weight1': weight1,
            'projection1': projection1,
            'bias1': None,
            'bn_weight1': bn_weight1,
            'bn_bias1': bn_bias1,
            'bn_mean1': bn_mean1,
            'bn_var1': bn_var1,
            'weight2': weight2,
            'projection2': projection2,
            'bias2': None,
            'bn_weight2': bn_weight2,
            'bn_bias2': bn_bias2,
            'bn_mean2': bn_mean2,
            'bn_var2': bn_var2
        })
        set_module_param(module_cur, params)
        # delete submodule parameter
        del_submodule_param(module_cur)
        # register module and parameter editing hook
        para_edit_hook_handle = module_cur.register_forward_pre_hook(
            module_edit_hook)
        module_cur.para_edit_hook_handle = para_edit_hook_handle

    ## optimize for the ResBlock
    args_optim = edict({
        'optimizer': args.prune_solver,
        'momentum': 0.9,
        'nesterov': False,
        'betas': (0.9, 0.999),
        'epsilon': 1e-8,
        'weight_decay': args.prune_weight_decay,
        'lr': args.prune_lr,
        'decay': args.prune_decay,
        'gamma': 0.1
    })
    # fine-tuning optimizer
    optimizer = make_optimizer(args_optim,
                               net_current,
                               separate=True,
                               scale=0.1)
    scheduler = make_scheduler(args_optim, optimizer)

    net_current.to(torch.device('cuda'))
    # net_parent.to(torch.device('cuda'))
    # global feature_map  # , feature_map_inte
    for b, (img, label) in enumerate(data):
        if b < 100:

            for i in range(1):
                # if b % 10 == 0:
                print(
                    '\nCompress All ResBlocks, Batch {}, Iteration {}, Last lr {:2.5f}'
                    .format(b, scheduler.last_epoch,
                            scheduler.get_lr()[0]))

                optimizer.zero_grad()
                img = img.to(torch.device('cuda'))
                prediction = net_current(img)

                def compute_loss(modules,
                                 prediction,
                                 label,
                                 lambda_factor=1.0,
                                 q=1):
                    """
                    loss = CE(prediction, label)
                           + lambda * (||A_1||_{2,1} + ||A_2^T||_{2,1}),
                    with the projection penalties averaged over all
                    compressed ResBlocks.
                    """
                    loss_proj1 = 0
                    loss_proj2 = 0
                    for m in modules:
                        projection1 = m.projection1.squeeze().t()
                        projection2 = m.projection2.squeeze().t()
                        # accumulate the group-sparsity (l_{2,q}) norms over
                        # the columns of projection1 / rows of projection2
                        loss_proj1 = loss_proj1 + torch.sum(
                            torch.sum(projection1**2, dim=0)
                            **(q / 2))**(1 / q) * lambda_factor
                        loss_proj2 = loss_proj2 + torch.sum(
                            torch.sum(projection2**2, dim=1)
                            **(q / 2))**(1 / q) * lambda_factor
                    loss_proj1 /= len(modules)
                    loss_proj2 /= len(modules)
                    loss_function_ce = nn.CrossEntropyLoss()
                    loss_acc = loss_function_ce(prediction, label.cuda())
                    loss = loss_proj1 + loss_proj2 + loss_acc
                    with print_array_on_one_line():
                        print(
                            'Current Loss {:>2.5f}: projection1 loss {:>2.5f}, projection2 loss {:>2.5f}, '
                            'classification loss: {:>2.5f}'.format(
                                loss.detach().cpu().numpy(),
                                loss_proj1.detach().cpu().numpy(),
                                loss_proj2.detach().cpu().numpy(),
                                loss_acc.detach().cpu().numpy()))
                    return loss

                loss = compute_loss(modules,
                                    prediction,
                                    label,
                                    lambda_factor=args.prune_regularization,
                                    q=args.q)
                torch.save(
                    {
                        'img': img,
                        'label': label,
                        'prediction': prediction,
                        'loss': loss
                    }, 'batch.pt')

                loss.backward()
                optimizer.step()
                scheduler.step()
                torch.save(net_current.state_dict(), 'before1.pt')
        else:
            break
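
The regularizer inside compute_loss is an l_{2,q} group-sparsity norm: with q = 1 it sums the l2 norms of the columns of projection1 (rows of projection2), so whole channels can be driven to zero and later pruned. A tiny self-contained check of the q = 1 case:

import torch

# l_{2,1} norm = sum of the l2 norms of the columns
A = torch.tensor([[3.0, 0.0],
                  [4.0, 0.0]])
l21 = torch.sum(torch.sum(A**2, dim=0)**0.5)
print(l21)  # tensor(5.) -- the all-zero column costs nothing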
Example #5
def compress_resblock(module_current, module_parent, net_current, net_parent,
                      data, ckp, args):
    """
    module_current: the current module
    module_parent: the parent module
    net_current: the current full network
    net_parent: the parent full network
    data: used by the data-driven compression algorithm
    ckp: checkpoint used to write the log
    args: parsed command-line arguments controlling the pruning procedure
    """

    for (name_cur,
         module_cur), (name_par,
                       module_par) in zip(module_current._modules.items(),
                                          module_parent._modules.items()):
        if isinstance(module_cur, ResBlock):
            global resblock_counter, global_step
            global_step = 0
            resblock_counter += 1
            module_cur.prune_procedure = args.prune_procedure  # choices: 'final', 'complete', 'undergoing'; module_cur.optimization is changed in modify_submodules

            # get initialization values for the ResBlock to be compressed
            weight1, bn_weight1, bn_bias1, bn_mean1, bn_var1, _,\
                weight2, bn_weight2, bn_bias2, bn_mean2, bn_var2, _ = [p.data for k, p in module_cur.state_dict().items()]
            if 'disturbance' in args.prune_init_method:
                weight1, projection1 = init_weight_proj(
                    weight1, init_method=args.prune_init_method, d=0)
                weight2, projection2 = init_weight_proj(
                    weight2, init_method=args.prune_init_method, d=1)
            else:
                weight1, projection1 = init_weight_proj(
                    weight1, init_method=args.prune_init_method)
                weight2, projection2 = init_weight_proj(
                    weight2, init_method=args.prune_init_method)
            # modify submodules in the ResBlock
            modify_submodules(module_cur)
            # set ResBlock module params
            params = edict({
                'weight1': weight1,
                'projection1': projection1,
                'bias1': None,
                'bn_weight1': bn_weight1,
                'bn_bias1': bn_bias1,
                'bn_mean1': bn_mean1,
                'bn_var1': bn_var1,
                'weight2': weight2,
                'projection2': projection2,
                'bias2': None,
                'bn_weight2': bn_weight2,
                'bn_bias2': bn_bias2,
                'bn_mean2': bn_mean2,
                'bn_var2': bn_var2
            })
            set_module_param(module_cur, params)
            # delete submodule parameter
            del_submodule_param(module_cur)
            # register module and parameter editing hook
            para_edit_hook_handle = module_cur.register_forward_pre_hook(
                module_edit_hook)
            module_cur.para_edit_hook_handle = para_edit_hook_handle

            ## optimize for the ResBlock
            args_optim = edict({
                'optimizer': args.prune_solver,
                'momentum': 0.9,
                'nesterov': False,
                'betas': (0.9, 0.999),
                'epsilon': 1e-8,
                'weight_decay': args.prune_weight_decay,
                'lr': args.prune_lr,
                'decay': args.prune_decay,
                'gamma': 0.1
            })
            # fine-tuning optimizer
            optimizer = make_optimizer(args_optim,
                                       module_cur,
                                       separate=True,
                                       scale=0.1)
            scheduler = make_scheduler(args_optim, optimizer)
            # add feature map collection hook
            add_feature_map_handle(module_cur, store_output=True)
            add_feature_map_handle(module_par, store_output=True)
            # used to collect the feature map of the intermediate layers
            # add_feature_map_handle(module_cur._modules['body']._modules['1'], store_input=True, store_output=True)
            # add_feature_map_handle(module_par._modules['body']._modules['1'], store_input=True, store_output=True)

            net_current.to(torch.device('cuda'))
            net_parent.to(torch.device('cuda'))
            global feature_map  #, feature_map_inter
            # SGD optimization
            for b, (img, label) in enumerate(data):
                if b < args.prune_iteration:

                    for i in range(1):
                        # if b % 10 == 0:
                        print(
                            '\nCompress ResBlock {}, Batch {}, Iteration {}, Last lr {:2.5f}'
                            .format(resblock_counter - 1, b,
                                    scheduler.last_epoch,
                                    scheduler.get_lr()[0]))

                        optimizer.zero_grad()
                        img = img.to(torch.device('cuda'))
                        prediction = net_current(img)
                        output_comp = feature_map['output']
                        # intermediate feature map
                        # output_comp_inter = feature_map_inter['output']
                        # input_comp_inter = feature_map_inter['input']
                        net_parent(img)
                        output_parent = feature_map['output'].detach()
                        # intermediate feature map
                        # output_parent_inter = feature_map_inter['output']
                        # input_parent_inter = feature_map_inter['input']
                        padding = 1
                        normalize = True

                        # intermediate feature map
                        # p = os.path.join(ckp.args.dir_save, ckp.args.save, 'feature_grid_inter')
                        # if not os.path.exists(p):
                        #     os.makedirs(p)
                        # grid_comp_inter = feature_visualize(output_comp_inter, 16, 16, normalize=normalize, padding=padding)
                        # grid_parent_inter = feature_visualize(output_parent_inter, 16, 16, normalize=normalize, padding=padding)
                        # save_tensor_image(grid_comp_inter, p + '/grid_comp{}.png'.format(scheduler.last_epoch))
                        # save_tensor_image(grid_parent_inter, p + '/grid_parent{}.png'.format(scheduler.last_epoch))
                        # save_tensor_image(grid_parent_inter - grid_comp_inter, p + '/grid_sub{}.png'.format(scheduler.last_epoch))
                        # grid_comp_in = feature_visualize(input_comp_inter, 16, 16, normalize=normalize, padding=padding)
                        # grid_parent_in = feature_visualize(input_parent_inter, 16, 16, normalize=normalize, padding=padding)
                        # save_tensor_image(grid_comp_in, p + '/in_grid_comp{}.png'.format(scheduler.last_epoch))
                        # save_tensor_image(grid_parent_in, p + '/in_grid_parent{}.png'.format(scheduler.last_epoch))
                        # save_tensor_image(grid_parent_in - grid_comp_in, p + '/in_grid_sub{}.png'.format(scheduler.last_epoch))

                        p = os.path.join(ckp.args.dir_save, ckp.args.save,
                                         'feature_grid')
                        if not os.path.exists(p):
                            os.makedirs(p)
                        grid_comp = feature_visualize(output_comp,
                                                      16,
                                                      16,
                                                      normalize=normalize,
                                                      padding=padding)
                        grid_parent = feature_visualize(output_parent,
                                                        16,
                                                        16,
                                                        normalize=normalize,
                                                        padding=padding)
                        save_tensor_image(
                            grid_comp, p +
                            '/grid_comp{}.png'.format(scheduler.last_epoch))
                        save_tensor_image(
                            grid_parent, p +
                            '/grid_parent{}.png'.format(scheduler.last_epoch))
                        save_tensor_image(
                            grid_parent - grid_comp,
                            p + '/grid_sub{}.png'.format(scheduler.last_epoch))

                        # output_parent.requires_grad = False
                        loss = compute_loss(
                            output_parent,
                            output_comp,
                            module_cur.projection1,
                            module_cur.projection2,
                            prediction,
                            label,
                            lambda_factor=args.prune_regularization,
                            q=args.q)

                        if args.prune_procedure == 'complete':
                            optimizer.param_groups[0]['params'] = list(
                                filter(lambda x: x.requires_grad,
                                       module_cur.parameters()))
                            # for p in optimizer.param_groups[0]['params']:
                            #     if p.grad is not None:
                            #         p.grad.data = torch.zeros_like(p)
                        loss.backward()
                        optimizer.step()
                        scheduler.step()

                        # check the change rate of the projection, and the params and buffers of batchnorm
                        # global projection
                        # if b >= 1:
                        #     x = module_cur.projection1 / projection['projection1']
                        #     print('Divide projection1 max: {:>2.5f}, min {:>2.5f}'.
                        #           format(x.detach().cpu().max().numpy(), x.detach().cpu().min().numpy()))
                        #
                        #     x = module_cur.bn_weight1 / projection['bn_weight1']
                        #     print('Divide bn weight1  max: {:>2.5f}, min {:>2.5f}'.
                        #           format(x.detach().cpu().max().numpy(), x.detach().cpu().min().numpy()))
                        #
                        #     x = module_cur.bn_bias1 / projection['bn_bias1']
                        #     print('Divide bn bias1    max: {:>2.5f}, min {:>2.5f}'.
                        #           format(x.detach().cpu().max().numpy(), x.detach().cpu().min().numpy()))
                        #
                        #     x = module_cur.bn_mean1 / projection['bn_mean1']
                        #     print('Divide bn_mean1    max: {:>2.5f}, min {:>2.5f}'.
                        #           format(x.detach().cpu().max().numpy(), x.detach().cpu().min().numpy()))
                        #
                        #     x = module_cur.bn_var1 / projection['bn_var1']
                        #     print('Divide bn_var1     max: {:>2.5f}, min {:>2.5f}'.
                        #           format(x.detach().cpu().max().numpy(), x.detach().cpu().min().numpy()))
                        # projection = {'projection1': module_cur.projection1.clone(), 'projection2': module_cur.projection2.clone(),
                        #               'bn_weight1': module_cur.bn_weight1.clone(), 'bn_bias1': module_cur.bn_bias1,
                        #               'bn_mean1': module_cur.bn_mean1.clone(), 'bn_var1': module_cur.bn_var1}

                        # check the gradients of the weight and projection matrices
                        print(
                            'Projection1 grad max: {:2.5f}, min: {:2.5f}, '
                            'Weight1 grad max: {:2.5f}, min: {:2.5f}'.format(
                                module_cur.projection1.grad.max().item(),
                                module_cur.projection1.grad.min().item(),
                                module_cur.weight1.grad.max().item(),
                                module_cur.weight1.grad.min().item()))
                        print(
                            'Projection2 grad max: {:2.5f}, min: {:2.5f}, '
                            'Weight2 grad max: {:2.5f}, min: {:2.5f}'.format(
                                module_cur.projection2.grad.max().item(),
                                module_cur.projection2.grad.min().item(),
                                module_cur.weight2.grad.max().item(),
                                module_cur.weight2.grad.min().item()))

                else:
                    break

            # remove parameter editing hook
            module_cur.para_edit_hook_handle.remove()

            # remove feature map collection hook
            remove_feature_map_handle(module_cur, store_output=True)
            remove_feature_map_handle(module_par, store_output=True)
            # used to collect the feature map of the intermediate layers
            # remove_feature_map_handle(module_cur._modules['body']._modules['1'], store_input=True, store_output=True)
            # remove_feature_map_handle(module_par._modules['body']._modules['1'], store_input=True, store_output=True)

            # prepare parameters for fine-tuning
            module_cur.optimization = False
            print('\nPrepare for finetuning.')
            prune(module_cur)
        elif module_cur is None:
            continue
        else:
            compress_resblock(module_cur, module_par, net_current, net_parent,
                              data, ckp, args)
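
Example #5 relies on add_feature_map_handle / remove_feature_map_handle to capture the block outputs of both the parent and the compressed network. A minimal sketch of how such a hook could work, assuming a module-level feature_map dict (the real helpers are defined elsewhere in the repository):

import torch.nn as nn

feature_map = {}

def add_feature_map_handle_sketch(module, store_input=False, store_output=False):
    # hypothetical stand-in for add_feature_map_handle: stash the module's
    # input/output on every forward pass so two networks can be compared
    def hook(mod, inputs, output):
        if store_input:
            feature_map['input'] = inputs[0]
        if store_output:
            feature_map['output'] = output
    handle = module.register_forward_hook(hook)
    module.feature_map_handle = handle  # keep it so it can be removed later
    return handle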