# Esempio n. 1  (scraped example-page header; commented out so the file parses)
# 0
def prune_one_cnn(block_index, conv_index, conv_tensor, cuda, in_channels_keep_index, initializer_fn, keep_index,
                  layer_index, model, model_adapter, model_architecture, name, param_type, parameters, reset_index):
    """Rebuild one conv layer keeping only the selected filters, and record it.

    Builds a pruned replacement for ``conv_tensor``, installs it into ``model``
    through ``model_adapter``, appends the sorted keep indices to
    ``in_channels_keep_index`` (so the following layer can prune its input
    channels to match), and logs the surviving filter count per layer name in
    ``model_architecture``.

    NOTE(review): this call site passes ``name`` and ``parameters`` to
    ``create_conv_tensor`` while other call sites in this file use a
    5-argument form — confirm against the helper's actual signature.
    """
    pruned_conv = create_conv_tensor(conv_tensor, in_channels_keep_index, initializer_fn, keep_index,
                                     name,
                                     parameters, reset_index).to(cuda)
    model_adapter.set_layer(model, param_type, pruned_conv, conv_index, layer_index, block_index)
    # sort()[0] is the sorted values tensor; downstream code expects ordered indices.
    in_channels_keep_index.append(keep_index.sort()[0])
    model_architecture.setdefault(name, []).append(keep_index.shape[0])
def rpgp_target(epochs, target_prune_rate, remove_ratio, test_loader,
                tcriterion_func, **kwargs):
    """Prune a CNN toward ``target_prune_rate`` over ``epochs`` epochs (RPGP).

    Per epoch:
      1. run ``tcriterion_func`` once to obtain the training loss and a
         per-layer ranking of filters (``prune_index_dict``);
      2. walk every named parameter and rebuild conv/BN/FC layers so they
         keep only the surviving filters (some filters are merely *reset*,
         i.e. re-initialized but kept);
      3. rebuild the SGD optimizer and slice/reset its momentum buffers to
         match the new, smaller tensors;
      4. evaluate and log accuracy plus removed filter/parameter counts.

    Afterwards, optionally fine-tunes for ``final_fn[0]`` extra epochs at
    learning rate ``final_fn[1]``.

    Args:
        epochs: number of pruning epochs.
        target_prune_rate: fraction of each layer's filters to remove by the
            final epoch.
        remove_ratio: share of each epoch's prune set that is hard-removed
            rather than reset; forced to 1.0 on the last epoch.
        test_loader: validation data loader.
        tcriterion_func: callable ``(model, model_adapter, optimizer, cuda,
            train_loader, is_break) -> (total_loss, prune_index_dict)``.
        **kwargs: required keys ``optimizer``, ``model``, ``cuda``,
            ``train_loader``, ``train_ratio``, ``initializer_fn``,
            ``model_adapter``, ``logger``, ``is_expo``; optional
            ``logger_id``, ``is_break``, ``scheduler``, ``final_fn``.

    Returns:
        tuple: ``(loss_acc, model_architecture)`` — per-epoch
        ``(mean loss, accuracy)`` pairs, and the per-layer history of kept
        filter counts.
    """
    optimizer = kwargs["optimizer"]
    model = kwargs["model"]
    cuda = kwargs["cuda"]
    train_loader = kwargs["train_loader"]
    train_ratio = kwargs["train_ratio"]  # read for API symmetry; unused here
    initializer_fn = kwargs["initializer_fn"]
    model_adapter = kwargs["model_adapter"]

    logger = kwargs["logger"]
    logger_id = kwargs.get("logger_id", "")
    is_break = bool(kwargs.get("is_break"))
    scheduler = kwargs.get("scheduler")
    # final_fn is either falsy (fine-tuning disabled) or a pair
    # (extra_epochs, fine_tune_lr).
    final_fn = kwargs.get("final_fn") or 0
    is_expo = kwargs["is_expo"]

    loss_acc = []
    type_list = []  # ParameterType of every parameter, in traversal order
    finished_list = False
    model_architecture = {}  # layer name -> kept-filter count per epoch
    removed_filters_total = 0
    forced_remove = False
    parameters_hard_removed_total = 0
    decay_rates_c = {}  # layer name -> per-epoch decay of its filter budget
    original_c = {}     # layer name -> original number of output channels

    # Pre-compute how fast each prunable conv layer's filter budget shrinks:
    # exponential decay (is_expo) or linear decay across the epochs.
    for name, parameters in model.named_parameters():
        param_type, tensor_index, layer_index, block_index = \
            model_adapter.get_param_type_and_layer_index(name)
        if param_type == ParameterType.CNN_WEIGHTS or \
                param_type == ParameterType.DOWNSAMPLE_WEIGHTS:
            if is_expo:
                decay_rates_c[name] = (np.log(parameters.shape[0]) - np.log(
                    parameters.shape[0] * (1 - target_prune_rate))) / epochs
            else:
                decay_rates_c[name] = \
                    target_prune_rate * parameters.shape[0] / epochs
            original_c[name] = parameters.shape[0]
            model_architecture[name] = []

    for epoch in range(1, epochs + 1):
        if scheduler is not None:
            scheduler.step()
        model.train()

        # One epoch of gradient work for the pruning criterion.
        optimizer.zero_grad()

        # time.clock() was removed in Python 3.8; perf_counter() measures
        # elapsed wall-clock time.
        start = time.perf_counter()
        total_loss, prune_index_dict = tcriterion_func(model, model_adapter,
                                                       optimizer, cuda,
                                                       train_loader, is_break)
        end = time.perf_counter()
        acc = eval(model, cuda, test_loader, is_break)

        if logger is not None:
            # NOTE(review): the "ppgp" prefix differs from this function's
            # "rpgp" name; kept as-is so existing log series stay intact.
            logger.log_scalar("ppgp_target_{}_epoch_time".format(logger_id),
                              end - start, epoch)  # was start - end (negative)
            logger.log_scalar("ppgp_target_{}_training_loss".format(logger_id),
                              total_loss, epoch)
            logger.log_scalar(
                "ppgp_target_{}_before_target_val_acc".format(logger_id), acc,
                epoch)

        if epoch == epochs:
            # Final epoch: hard-remove everything still scheduled for pruning.
            remove_ratio = 1.

        out_channels_keep_indexes = []
        in_channels_keep_indexes = []
        reset_indexes = []
        original_out_channels = 0
        first_fc = False
        current_ids = {}     # param name -> id() before layers are rebuilt
        start_index = None   # (keep_index, reset_index) of the block entry conv
        last_start_conv = None
        last_keep_index = None
        removed_filters_total_epoch = 0
        reset_filters_total_epoch = 0
        parameters_hard_removed_per_epoch = 0
        parameters_reset_removed = 0

        for name, parameters in model.named_parameters():

            current_ids[name] = id(parameters)
            param_type, tensor_index, layer_index, block_index = \
                model_adapter.get_param_type_and_layer_index(name)
            if not finished_list:
                type_list.append(param_type)

            if layer_index == -1:
                # Stem layers: CNN and BN before the residual blocks.
                if param_type == ParameterType.CNN_WEIGHTS:
                    sorted_filters_index = prune_index_dict[name]
                    conv_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)
                    original_out_channels = parameters.shape[0]

                    if is_expo:
                        num_remain_target_prune = original_c[name] * np.exp(
                            -decay_rates_c[name] * (epoch + 1))
                    else:
                        num_remain_target_prune = original_c[
                            name] - decay_rates_c[
                                name] * epoch  # maybe have to +1 here?

                    keep_index, reset_index = get_prune_index_target_with_reset(
                        original_out_channels, num_remain_target_prune,
                        remove_ratio, sorted_filters_index, forced_remove)
                    if reset_index is not None:
                        # Reset filters survive (re-initialized), so they
                        # belong to the keep set.
                        keep_index = torch.cat((keep_index, reset_index))
                    new_conv_tensor = create_conv_tensor(
                        conv_tensor, out_channels_keep_indexes, initializer_fn,
                        keep_index, reset_index).to(cuda)
                    model_adapter.set_layer(model, param_type, new_conv_tensor,
                                            tensor_index, layer_index,
                                            block_index)

                    if name not in model_architecture:
                        model_architecture[name] = []
                    model_architecture[name].append(keep_index.shape[0])

                    removed_filters_total_epoch += \
                        original_out_channels - keep_index.shape[0]
                    if reset_index is not None:
                        reset_filters_total_epoch += len(reset_index)

                    in_c = 3  # stem conv reads RGB unless a layer precedes it
                    if len(out_channels_keep_indexes) != 0:
                        in_c = out_channels_keep_indexes[-1].shape[0]

                    parameters_hard_removed_per_epoch += \
                        (original_out_channels - keep_index.shape[0]) * \
                        in_c * parameters.shape[2:].numel()
                    parameters_reset_removed += 0 if reset_index is None or len(
                        reset_index) == 0 else len(
                            reset_index) * in_c * parameters.shape[2:].numel()

                    start_index = (keep_index.sort()[0], reset_index)
                    reset_indexes.append(reset_index)
                    if out_channels_keep_indexes is not None and len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    else:
                        in_channels_keep_indexes.append(None)
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.CNN_BIAS:
                    # Bias follows its conv: reuse the previous entry's masks.
                    reset_indexes.append(reset_indexes[-1])
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1])

                elif param_type == ParameterType.BN_WEIGHT:
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index = out_channels_keep_indexes[-1]
                    reset_index = reset_indexes[-1]

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    reset_indexes.append(reset_index)
                    # was `or` (always true; crashed on an empty list at
                    # [-1]); `and` matches the CNN_WEIGHTS guard above
                    if out_channels_keep_indexes is not None and len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_BIAS:
                    # Relies on keep_index/reset_index carried over from the
                    # BN_WEIGHT branch of the preceding parameter.
                    reset_indexes.append(reset_index)
                    if out_channels_keep_indexes is not None and len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.FC_WEIGHTS and first_fc == False:
                    fc_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    # Drop the FC input columns that belonged to removed
                    # filters of the last conv layer.
                    new_fc_weight = prune_fc_like(
                        fc_tensor.weight.data, out_channels_keep_indexes[-1],
                        original_out_channels)

                    new_fc_bias = None
                    if fc_tensor.bias is not None:
                        new_fc_bias = fc_tensor.bias.data
                    new_fc_tensor = nn.Linear(new_fc_weight.shape[1],
                                              new_fc_weight.shape[0],
                                              bias=new_fc_bias
                                              is not None).to(cuda)
                    new_fc_tensor.weight.data = new_fc_weight
                    if fc_tensor.bias is not None:
                        new_fc_tensor.bias.data = new_fc_bias
                    model_adapter.set_layer(model, param_type, new_fc_tensor,
                                            tensor_index, layer_index,
                                            block_index)
                    first_fc = True
                    finished_list = True  # type_list is now complete

            else:
                # Parameters inside residual blocks.
                if param_type == ParameterType.CNN_WEIGHTS:

                    if tensor_index == 1:
                        # First conv of the block: choose its own keep set.
                        sorted_filters_index = prune_index_dict[name]
                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)
                        original_out_channels = parameters.shape[0]

                        if is_expo:
                            num_remain_target_prune = original_c[
                                name] * np.exp(-decay_rates_c[name] *
                                               (epoch + 1))
                        else:
                            num_remain_target_prune = original_c[
                                name] - decay_rates_c[
                                    name] * epoch  # maybe have to +1 here?

                        keep_index, reset_index = get_prune_index_target_with_reset(
                            original_out_channels, num_remain_target_prune,
                            remove_ratio, sorted_filters_index, forced_remove)

                        if reset_index is not None:
                            keep_index = torch.cat((keep_index, reset_index))
                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes,
                            initializer_fn, keep_index, reset_index).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        if name not in model_architecture:
                            model_architecture[name] = []
                        model_architecture[name].append(keep_index.shape[0])

                        removed_filters_total_epoch += \
                            original_out_channels - keep_index.shape[0]
                        if reset_index is not None:
                            reset_filters_total_epoch += len(reset_index)

                        in_c = conv_tensor.in_channels
                        if len(out_channels_keep_indexes) != 0:
                            in_c = out_channels_keep_indexes[-1].shape[0]

                        parameters_hard_removed_per_epoch += \
                            (original_out_channels - keep_index.shape[0]) * \
                            in_c * parameters.shape[2:].numel()
                        parameters_reset_removed += 0 if reset_index is None or len(
                            reset_index) == 0 else len(
                                reset_index) * in_c * parameters.shape[2:].numel()

                        reset_indexes.append(reset_index)
                        # was `or`; see note in the stem BN_WEIGHT branch
                        if out_channels_keep_indexes is not None and len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])
                        out_channels_keep_indexes.append(keep_index.sort()[0])

                    elif tensor_index == 2:
                        # Second conv: its outputs must line up with the
                        # residual path, so it reuses start_index; if the
                        # block has a downsample conv, prepare its pruned
                        # replacement here as well.
                        downsample_cnn, d_name = model_adapter.get_downsample(
                            model, layer_index, block_index)
                        if downsample_cnn is not None:

                            sorted_filters_index = prune_index_dict[d_name]
                            original_out_channels = parameters.shape[0]
                            last_keep_index, _ = start_index

                            if is_expo:
                                num_remain_target_prune = original_c[
                                    d_name] * np.exp(-decay_rates_c[d_name] *
                                                     (epoch + 1))
                            else:
                                num_remain_target_prune = original_c[
                                    d_name] - decay_rates_c[
                                        d_name] * epoch  # maybe have to +1 here?

                            # Deliberate override: downsample convs are not
                            # pruned in this experiment.
                            num_remain_target_prune = original_c[d_name]

                            keep_index, reset_index = get_prune_index_target_with_reset(
                                original_out_channels, num_remain_target_prune,
                                remove_ratio, sorted_filters_index,
                                forced_remove)
                            if reset_index is not None:
                                keep_index = torch.cat(
                                    (keep_index, reset_index))
                            # Built now, installed later when the matching
                            # DOWNSAMPLE_WEIGHTS parameter is visited.
                            last_start_conv = create_conv_tensor(
                                downsample_cnn, [last_keep_index],
                                initializer_fn, keep_index,
                                reset_index).to(cuda)
                            last_start_conv = [
                                last_start_conv, 0, layer_index, block_index
                            ]

                            if d_name not in model_architecture:
                                model_architecture[d_name] = []
                            model_architecture[d_name].append(
                                keep_index.shape[0])
                            if reset_index is not None:
                                removed_filters_total_epoch += \
                                    original_out_channels - \
                                    keep_index.shape[0] - len(reset_index)
                                reset_filters_total_epoch += len(reset_index)
                            else:
                                removed_filters_total_epoch += \
                                    original_out_channels - keep_index.shape[0]
                            parameters_hard_removed_per_epoch += \
                                (original_out_channels - keep_index.shape[0]) * \
                                last_keep_index.shape[0] * \
                                parameters.shape[2:].numel()
                            parameters_reset_removed += 0 if reset_index is None or len(
                                reset_index) == 0 else len(
                                    reset_index) * last_keep_index.shape[
                                        0] * parameters.shape[2:].numel()

                            start_index = (keep_index.sort()[0], reset_index)

                        original_out_channels = parameters.shape[0]
                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)
                        keep_index, reset_index = start_index

                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes,
                            initializer_fn, keep_index, reset_index).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        reset_indexes.append(reset_index)
                        # was `or`; see note in the stem BN_WEIGHT branch
                        if out_channels_keep_indexes is not None and len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])

                        removed_filters_total_epoch += \
                            original_out_channels - keep_index.shape[0]
                        if reset_index is not None:
                            reset_filters_total_epoch += len(reset_index)
                        # NOTE: out_channels_keep_indexes[-1] is still the
                        # previous layer's keep set here (append comes after).
                        parameters_hard_removed_per_epoch += \
                            (original_out_channels - keep_index.shape[0]) * \
                            out_channels_keep_indexes[-1].shape[0] * \
                            parameters.shape[2:].numel()
                        parameters_reset_removed += 0 if reset_index is None or len(
                            reset_index) == 0 else len(
                                reset_index) * out_channels_keep_indexes[-1].shape[
                                    0] * parameters.shape[2:].numel()

                        out_channels_keep_indexes.append(keep_index.sort()[0])
                        if name not in model_architecture:
                            model_architecture[name] = []
                        model_architecture[name].append(keep_index.shape[0])

                elif param_type == ParameterType.DOWNSAMPLE_WEIGHTS:
                    # Install the downsample conv prepared in the
                    # tensor_index == 2 branch above.
                    last_start_conv, tensor_index, layer_index, block_index = \
                        last_start_conv
                    model_adapter.set_layer(model,
                                            ParameterType.DOWNSAMPLE_WEIGHTS,
                                            last_start_conv, tensor_index,
                                            layer_index, block_index)

                    keep_index, reset_index = start_index
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(last_keep_index.sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_WEIGHT:
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index = out_channels_keep_indexes[-1]
                    reset_index = reset_indexes[-1]

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_BIAS:
                    # keep_index is carried over from the preceding
                    # BN_WEIGHT / downsample parameter.
                    reset_indexes.append(reset_indexes[-1])
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.DOWNSAMPLE_BN_W:

                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index, reset_index = start_index

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.DOWNSAMPLE_BN_B:
                    keep_index, reset_index = start_index
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.CNN_BIAS:
                    reset_indexes.append(reset_indexes[-1])
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1])

        # Rebuilding layers created new Parameter objects; map new ids back
        # to the old ones so momentum buffers can be transferred.
        new_old_ids = {}
        new_ids = {}
        for k, v in model.named_parameters():
            new_id = id(v)
            new_ids[k] = new_id
            new_old_ids[new_id] = current_ids[k]

        o_state_dict = optimizer.state_dict()
        optimizer = optim.SGD(model.parameters(),
                              lr=optimizer.param_groups[0]["lr"],
                              momentum=optimizer.param_groups[0]["momentum"])
        n_new_state_dict = optimizer.state_dict()

        # NOTE(review): this keys state_dict entries by id() — relies on the
        # old (pre-1.0) PyTorch state_dict format plus the custom
        # in_place_load_state_dict helper; verify against the torch version
        # this project pins.
        for k in n_new_state_dict["param_groups"][0]["params"]:
            old_id = new_old_ids[k]
            old_momentum = o_state_dict["state"][old_id]
            n_new_state_dict["state"][k] = old_momentum
        in_place_load_state_dict(optimizer, n_new_state_dict)

        # For every parameter decide how its momentum buffer must be sliced:
        # (param_type, out-keep indices, reset indices, in-keep indices).
        index_op_dict = {}
        first_fc = False
        for i in range(len(type_list)):
            if type_list[i] == ParameterType.FC_WEIGHTS and first_fc == False:
                index_op_dict[optimizer.param_groups[0]['params'][i]] = (
                    type_list[i], out_channels_keep_indexes[i - 1], None, None)
                first_fc = True
            elif type_list[i] == ParameterType.FC_BIAS:
                continue
            elif type_list[i] == ParameterType.DOWNSAMPLE_BN_B or type_list[
                    i] == ParameterType.DOWNSAMPLE_BN_W or type_list[
                        i] == ParameterType.BN_BIAS or type_list[
                            i] == ParameterType.BN_WEIGHT:
                index_op_dict[optimizer.param_groups[0]['params'][i]] = (
                    type_list[i], out_channels_keep_indexes[i],
                    reset_indexes[i], None)
            else:
                index_op_dict[optimizer.param_groups[0]['params'][i]] = (
                    type_list[i], out_channels_keep_indexes[i],
                    reset_indexes[i], in_channels_keep_indexes[i])

        for k, v in index_op_dict.items():
            if v[0] == ParameterType.CNN_WEIGHTS or v[
                    0] == ParameterType.DOWNSAMPLE_WEIGHTS:
                if v[3] is not None and len(v[3]):
                    # Slice input channels, then re-initialize the buffer
                    # rows of reset filters.
                    optimizer.state[k]["momentum_buffer"] = optimizer.state[k][
                        "momentum_buffer"][:, v[3], :, :]
                    if v[2] is not None:
                        optimizer.state[k]["momentum_buffer"][
                            v[2]] = initializer_fn(
                                optimizer.state[k]["momentum_buffer"][v[2]])
                optimizer.state[k]['momentum_buffer'] = optimizer.state[k][
                    'momentum_buffer'][v[1], :, :, :]

            elif v[0] == ParameterType.CNN_BIAS or v[0] == ParameterType.BN_WEIGHT or v[0] == ParameterType.BN_BIAS \
                    or v[0] == ParameterType.DOWNSAMPLE_BN_W or v[0] == ParameterType.DOWNSAMPLE_BN_B:
                if v[2] is not None:
                    optimizer.state[k]["momentum_buffer"][
                        v[2]] = initializer_fn(
                            optimizer.state[k]["momentum_buffer"][v[2]])
                optimizer.state[k]['momentum_buffer'] = optimizer.state[k][
                    'momentum_buffer'][v[1]]
            else:
                optimizer.state[k]['momentum_buffer'] = \
                    prune_fc_like(optimizer.state[k]['momentum_buffer'], v[1],
                                  original_out_channels)

        removed_filters_total += removed_filters_total_epoch
        parameters_hard_removed_total += parameters_hard_removed_per_epoch

        acc = eval(model, cuda, test_loader, is_break)
        if logger is not None:
            logger.log_scalar(
                "rpgp_target_{}_after_target_val_acc".format(logger_id), acc,
                epoch)
            logger.log_scalar(
                "rpgp_target_{}_number of filter removed".format(logger_id),
                removed_filters_total + reset_filters_total_epoch, epoch)
            logger.log_scalar(
                "rpgp_target_{}_acc_number of filter removed".format(
                    logger_id), acc,
                removed_filters_total + reset_filters_total_epoch)
            logger.log_scalar(
                "rpgp_target_{}_acc_number of parameters removed".format(
                    logger_id), acc,
                parameters_hard_removed_total + parameters_reset_removed)

        loss_acc.append((total_loss / len(train_loader), acc))

    # Optional fine-tuning phase after pruning has finished.  Guard with
    # `final_fn and` — the old bare `final_fn[0]` raised TypeError whenever
    # final_fn was left at its default of 0.
    if final_fn and final_fn[0]:
        optimizer.param_groups[0]["lr"] = final_fn[1]
        for epoch in range(epochs + 1, epochs + final_fn[0] + 1):
            if scheduler is not None:
                scheduler.step()
            model.train()

            optimizer.zero_grad()

            start = time.perf_counter()
            total_loss = train(model, optimizer, cuda, train_loader, is_break)
            end = time.perf_counter()
            acc = eval(model, cuda, test_loader, is_break)

            if logger is not None:
                logger.log_scalar("rpgp_{}_epoch_time".format(logger_id),
                                  end - start, epoch)
                logger.log_scalar("rpgp_{}_training_loss".format(logger_id),
                                  total_loss, epoch)
                logger.log_scalar(
                    "rpgp_{}_before_target_val_acc".format(logger_id), acc,
                    epoch)

    return loss_acc, model_architecture
# Esempio n. 3  (scraped example-page header; commented out so the file parses)
# 0
def pgp_fasterRCNN(epochs, target_prune_rate, remove_ratio, criterion_func,
                   **kwargs):
    """Train a Faster R-CNN model while progressively hard-pruning CNN filters.

    Each epoch: (1) train for one epoch, (2) rank filters with
    ``criterion_func``, (3) walk every parameter of the model and rebuild the
    conv/BN/FC layers with only the kept filter indexes, (4) rebuild the SGD
    optimizer and remap/prune its momentum buffers to the surviving filters,
    (5) evaluate mAP and log pruning statistics.

    Args:
        epochs: Number of pruning epochs to run.
        target_prune_rate: Fraction of each conv layer's filters to remove by
            the end of training (drives the per-layer linear decay schedule).
        remove_ratio: Fraction of the scheduled weak filters that are hard
            removed (the rest are reset); forwarded to ``get_weak_fn``.
        criterion_func: Callable ranking filters; returns
            ``(prune_index_dict, values)`` where ``prune_index_dict`` maps
            parameter names to filter indexes sorted weakest-first.
        **kwargs: Must contain ``frcnn_extra``, ``optimizer``, ``model``,
            ``cuda``, ``initializer_fn``, ``model_adapter``, ``logger``;
            optionally ``logger_id``, ``is_break``, ``is_conservative``.

    Returns:
        tuple: ``(loss_acc, model_architecture)`` — ``loss_acc`` is left empty
        by this function (mAP goes to the logger instead);
        ``model_architecture`` maps parameter name -> list of kept-filter
        counts per epoch.
    """

    frcnn_extra = kwargs["frcnn_extra"]
    # NOTE(review): SHIFT_OPTI and FREEZE_FIRST_NUM are never read in this
    # function body; the res101 branch re-assigns SHIFT_OPTI to the same
    # value it already has. Presumably leftovers — confirm before removing.
    SHIFT_OPTI = 8
    FREEZE_FIRST_NUM = 10
    if frcnn_extra.net == "res101":
        SHIFT_OPTI = 8

    optimizer = kwargs["optimizer"]
    model = kwargs["model"]
    cuda = kwargs["cuda"]
    initializer_fn = kwargs["initializer_fn"]
    model_adapter = kwargs["model_adapter"]

    logger = kwargs["logger"]
    logger_id = ""
    if "logger_id" in kwargs:
        logger_id = kwargs["logger_id"]

    is_break = False
    if "is_break" in kwargs and kwargs["is_break"]:
        is_break = True

    # NOTE(review): till_break / is_conservative are set here but never used
    # below — possibly consumed by criterion_func through kwargs; verify.
    till_break = False
    is_conservative = False
    if "is_conservative" in kwargs and kwargs["is_conservative"] is not None:
        is_conservative = kwargs["is_conservative"]
        till_break = True

    # criterion_func receives kwargs, so expose the FRCNN train loader to it.
    kwargs["train_loader"] = frcnn_extra.dataloader_train

    loss_acc = []
    type_list = []          # ParameterType of each trainable param, in order
    finished_list = False   # once True, stop appending to type_list
    model_architecture = OrderedDict()
    removed_filters_total = 0
    forced_remove = False
    same_three = 0          # NOTE(review): unused in this function body
    parameters_hard_removed_total = 0
    get_weak_fn = get_prune_index_target_with_reset
    lr = optimizer.param_groups[0]['lr']

    # Per-layer linear schedule: remove decay_rates_c[name] filters per epoch
    # so that after `epochs` epochs, target_prune_rate of the layer is gone.
    decay_rates_c = OrderedDict()
    original_c = OrderedDict()
    for name, parameters in model.named_parameters():
        param_type, tensor_index, layer_index, block_index = model_adapter.get_param_type_and_layer_index(
            name)
        if param_type == ParameterType.CNN_WEIGHTS or param_type == ParameterType.DOWNSAMPLE_WEIGHTS:
            decay_rates_c[
                name] = target_prune_rate * parameters.shape[0] / epochs
            original_c[name] = parameters.shape[0]
            model_architecture[name] = []

    for epoch in range(1, epochs + 1):

        # NOTE(review): time.clock() was removed in Python 3.8 — needs
        # time.perf_counter() on modern interpreters.
        start = time.clock()
        total_loss = train_frcnn(frcnn_extra, cuda, model, optimizer, is_break)
        end = time.clock()
        if logger is not None:
            # NOTE(review): logs time.clock() - end (≈0), not end - start;
            # epoch duration is effectively never recorded — likely a bug.
            logger.log_scalar(
                "pgp_target_frcnn_{}_epoch_time".format(logger_id),
                time.clock() - end, epoch)
            logger.log_scalar(
                "pgp_target_frcnn_{}_training_loss".format(logger_id),
                total_loss, epoch)

        if epoch % (frcnn_extra.lr_decay_step + 1) == 0:
            adjust_learning_rate(optimizer, frcnn_extra.lr_decay_gamma)
            lr *= frcnn_extra.lr_decay_gamma

        # Rank filters (weakest first) for every prunable layer.
        prune_index_dict, _ = criterion_func(**kwargs)

        # Per-epoch bookkeeping. The three lists below are index-aligned with
        # the order of model.named_parameters() (trainable params only) and
        # later with optimizer.param_groups — do not reorder appends.
        out_channels_keep_indexes = []
        in_channels_keep_indexes = []
        reset_indexes = []
        original_out_channels = 0
        first_fc = False
        current_ids = OrderedDict()   # param name -> id() before rebuild
        start_index = None            # (keep_index, reset_index) of block entry conv
        last_start_conv = None        # pending rebuilt downsample conv + location
        last_keep_index = None
        removed_filters_total_epoch = 0
        reset_filters_total_epoch = 0
        parameters_hard_removed_per_epoch = 0
        parameters_reset_removed = 0

        # print(epoch)
        # Snapshot optimizer state so momentum buffers can be remapped onto
        # the rebuilt (smaller) parameters afterwards.
        o_state_dict = optimizer.state_dict()
        for name, parameters in model.named_parameters():
            current_ids[name] = id(parameters)
            param_type, tensor_index, layer_index, block_index = model_adapter.get_param_type_and_layer_index(
                name)
            if not finished_list and parameters.requires_grad:
                type_list.append(param_type)

            if not parameters.requires_grad:
                continue

            if param_type is None:
                # Unrecognized param: keep list alignment with placeholders.
                reset_indexes.append([])
                out_channels_keep_indexes.append([])
                in_channels_keep_indexes.append([])
                continue

            if layer_index == -1:
                # Handling CNN and BN before Resnet

                if tensor_index == model_adapter.last_layer_index:
                    # Last conv before the head: keep ALL of its filters so the
                    # following FC input size only shrinks via in-channels.
                    if param_type == ParameterType.CNN_WEIGHTS:
                        original_out_channels = parameters.shape[0]
                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)

                        keep_index = torch.arange(
                            0, original_out_channels).long()
                        reset_index = []

                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes,
                            initializer_fn, keep_index, None).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        in_c = parameters.shape[1]
                        if len(out_channels_keep_indexes) != 0:
                            in_c = out_channels_keep_indexes[-1].shape[0]

                        parameters_hard_removed_per_epoch += (original_out_channels - keep_index.shape[0]) * \
                                                             in_c * parameters.shape[2:].numel()
                        parameters_reset_removed += 0 if reset_index is None or len(
                            reset_index
                        ) == 0 else len(
                            reset_index) * in_c * parameters.shape[2:].numel()

                        reset_indexes.append(reset_index)
                        if out_channels_keep_indexes is not None and len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])
                        else:
                            in_channels_keep_indexes.append(None)
                        out_channels_keep_indexes.append(keep_index.sort()[0])
                    elif param_type == ParameterType.CNN_BIAS:
                        # Bias mirrors the conv weight entry just appended.
                        reset_indexes.append(reset_indexes[-1])
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                        out_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1])
                    continue

                if param_type == ParameterType.CNN_WEIGHTS:
                    original_out_channels = parameters.shape[0]
                    conv_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    if name in prune_index_dict:
                        # Scheduled pruning: split into kept vs reset filters.
                        sorted_filters_index = prune_index_dict[name]
                        keep_index, reset_index = get_weak_fn(
                            original_out_channels,
                            0,
                            remove_ratio,
                            sorted_filters_index,
                            forced_remove,
                            original_c=original_c[name],
                            decay_rates_c=decay_rates_c[name],
                            epoch=epoch)
                        if reset_index is not None:
                            # Reset filters stay in the layer (re-initialized),
                            # so they are part of the kept set.
                            keep_index = torch.cat((keep_index, reset_index))
                    else:
                        keep_index = torch.arange(
                            0, original_out_channels).long()
                        reset_index = []

                    new_conv_tensor = create_conv_tensor(
                        conv_tensor, out_channels_keep_indexes, initializer_fn,
                        keep_index, reset_index).to(cuda)
                    model_adapter.set_layer(model, param_type, new_conv_tensor,
                                            tensor_index, layer_index,
                                            block_index)

                    if name not in model_architecture:
                        model_architecture[name] = []
                    model_architecture[name].append(keep_index.shape[0])

                    removed_filters_total_epoch += original_out_channels - keep_index.shape[
                        0]
                    reset_filters_total_epoch += len(reset_index)

                    # First conv takes RGB input unless a previous layer was pruned.
                    in_c = 3
                    if len(out_channels_keep_indexes) != 0 and len(
                            out_channels_keep_indexes[-1]):
                        in_c = out_channels_keep_indexes[-1].shape[0]

                    parameters_hard_removed_per_epoch += (original_out_channels - keep_index.shape[0]) * \
                                                         in_c * parameters.shape[2:].numel()
                    parameters_reset_removed += 0 if reset_index is None or len(
                        reset_index) == 0 else len(
                            reset_index) * in_c * parameters.shape[2:].numel()

                    # Remember this conv's keep/reset set: residual-block exit
                    # convs and downsample branches must match it.
                    start_index = (keep_index.sort()[0], reset_index)

                    reset_indexes.append(reset_index)
                    if out_channels_keep_indexes is not None and len(
                            out_channels_keep_indexes) != 0 and len(
                                out_channels_keep_indexes[-1]):
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    else:
                        in_channels_keep_indexes.append(None)
                    out_channels_keep_indexes.append(keep_index.sort()[0])
                elif param_type == ParameterType.CNN_BIAS:
                    reset_indexes.append(reset_indexes[-1])
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1])

                elif param_type == ParameterType.BN_WEIGHT:
                    # BN follows its conv: shrink to the same kept channels.
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index = out_channels_keep_indexes[-1]
                    reset_index = reset_indexes[-1]

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    reset_indexes.append(reset_index)
                    # NOTE(review): `is not None or len(...) != 0` is always
                    # True (the list exists); the analogous CNN branch above
                    # uses `and` with an explicit None fallback — likely a
                    # typo, but kept as-is since list alignment depends on it.
                    if out_channels_keep_indexes is not None or len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_BIAS:
                    # NOTE(review): relies on keep_index/reset_index still
                    # holding the values set by the preceding BN_WEIGHT
                    # iteration (loop-carried state) — confirm ordering holds
                    # for every adapter.
                    reset_indexes.append(reset_index)
                    if out_channels_keep_indexes is not None or len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.FC_WEIGHTS and first_fc == False:
                    # First FC after the conv trunk: prune its input features
                    # to match the surviving conv channels; outputs unchanged.
                    fc_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)
                    new_fc_weight = prune_fc_like(
                        fc_tensor.weight.data, out_channels_keep_indexes[-1],
                        original_out_channels)

                    new_fc_bias = None
                    if fc_tensor.bias is not None:
                        new_fc_bias = fc_tensor.bias.data
                    new_fc_tensor = nn.Linear(new_fc_weight.shape[1],
                                              new_fc_weight.shape[0],
                                              bias=new_fc_bias
                                              is not None).to(cuda)
                    new_fc_tensor.weight.data = new_fc_weight
                    if fc_tensor.bias is not None:
                        new_fc_tensor.bias.data = new_fc_bias
                    model_adapter.set_layer(model, param_type, new_fc_tensor,
                                            tensor_index, layer_index,
                                            block_index)
                    first_fc = True
                    finished_list = True

            else:
                # Inside a ResNet block (layer_index >= 0).

                if param_type == ParameterType.CNN_WEIGHTS:

                    if tensor_index == 1:
                        # First conv of the block: free to pick its own keep set.
                        original_out_channels = parameters.shape[0]
                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)

                        if name in prune_index_dict:
                            sorted_filters_index = prune_index_dict[name]
                            keep_index, reset_index = get_weak_fn(
                                original_out_channels,
                                0,
                                remove_ratio,
                                sorted_filters_index,
                                forced_remove,
                                original_c=original_c[name],
                                decay_rates_c=decay_rates_c[name],
                                epoch=epoch)
                            if reset_index is not None:
                                keep_index = torch.cat(
                                    (keep_index, reset_index))
                        else:
                            keep_index = torch.arange(
                                0, original_out_channels).long()
                            reset_index = []

                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes,
                            initializer_fn, keep_index, reset_index).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        if name not in model_architecture:
                            model_architecture[name] = []
                        model_architecture[name].append(keep_index.shape[0])

                        removed_filters_total_epoch += original_out_channels - keep_index.shape[
                            0]
                        reset_filters_total_epoch += len(reset_index)

                        in_c = conv_tensor.in_channels
                        if len(out_channels_keep_indexes) != 0:
                            in_c = out_channels_keep_indexes[-1].shape[0]

                        parameters_hard_removed_per_epoch += (original_out_channels - keep_index.shape[0]) * \
                                                             in_c * parameters.shape[
                                                                    2:].numel()
                        parameters_reset_removed += 0 if reset_index is None or len(
                            reset_index
                        ) == 0 else len(
                            reset_index) * in_c * parameters.shape[2:].numel()

                        reset_indexes.append(reset_index)
                        # NOTE(review): always-True `or` condition again (see
                        # BN_WEIGHT note above); kept byte-identical.
                        if out_channels_keep_indexes is not None or len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])
                        out_channels_keep_indexes.append(keep_index.sort()[0])

                    elif tensor_index == 2:
                        # Exit conv of the block: its outputs must line up with
                        # the residual/downsample path, so reuse start_index.

                        downsample_cnn, d_name = model_adapter.get_downsample(
                            model, layer_index, block_index)
                        if downsample_cnn is not None:
                            # Block has a projection shortcut: prune it with
                            # its own ranking, then it defines start_index.
                            original_out_channels = parameters.shape[0]
                            last_keep_index, _ = start_index
                            if d_name in prune_index_dict:
                                sorted_filters_index = prune_index_dict[d_name]
                                # conv_tensor.out_channels

                                keep_index, reset_index = get_weak_fn(
                                    original_out_channels,
                                    0,
                                    remove_ratio,
                                    sorted_filters_index,
                                    forced_remove,
                                    original_c=original_c[d_name],
                                    decay_rates_c=decay_rates_c[d_name],
                                    epoch=epoch)

                                if reset_index is not None:
                                    keep_index = torch.cat(
                                        (keep_index, reset_index))
                            else:
                                keep_index = torch.arange(
                                    0, original_out_channels).long()
                                reset_index = []

                            # Rebuild the downsample conv now, but install it
                            # later, when its own parameter comes up in the
                            # iteration (DOWNSAMPLE_WEIGHTS branch below).
                            last_start_conv = create_conv_tensor(
                                downsample_cnn, [last_keep_index],
                                initializer_fn, keep_index,
                                reset_index).to(cuda)
                            last_start_conv = [
                                last_start_conv, 0, layer_index, block_index
                            ]

                            if d_name not in model_architecture:
                                model_architecture[d_name] = []
                            model_architecture[d_name].append(
                                keep_index.shape[0])

                            removed_filters_total_epoch += original_out_channels - keep_index.shape[
                                0]
                            reset_filters_total_epoch += len(reset_index)
                            parameters_hard_removed_per_epoch += (original_out_channels - keep_index.shape[0]) * \
                                                                 last_keep_index.shape[0] * parameters.shape[2:].numel()
                            parameters_reset_removed += 0 if reset_index is None or len(
                                reset_index) == 0 else len(
                                    reset_index) * last_keep_index.shape[
                                        0] * parameters.shape[2:].numel()
                            start_index = (keep_index.sort()[0], reset_index)

                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)
                        keep_index, reset_index = start_index

                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes,
                            initializer_fn, keep_index, reset_index).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        reset_indexes.append(reset_index)
                        if out_channels_keep_indexes is not None or len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])

                        removed_filters_total_epoch += original_out_channels - keep_index.shape[
                            0]
                        reset_filters_total_epoch += len(reset_index)
                        parameters_hard_removed_per_epoch += (original_out_channels - keep_index.shape[0]) * \
                                                             out_channels_keep_indexes[-1].shape[0] * parameters.shape[
                                                                                                      2:].numel()
                        parameters_reset_removed += 0 if reset_index is None or len(
                            reset_index
                        ) == 0 else len(
                            reset_index) * out_channels_keep_indexes[-1].shape[
                                0] * parameters.shape[2:].numel()

                        out_channels_keep_indexes.append(keep_index.sort()[0])
                        if name not in model_architecture:
                            model_architecture[name] = []
                        model_architecture[name].append(keep_index.shape[0])

                elif param_type == ParameterType.DOWNSAMPLE_WEIGHTS:

                    # Install the downsample conv rebuilt in the tensor_index==2
                    # branch above (last_start_conv carries conv + location).
                    last_start_conv, tensor_index, layer_index, block_index = last_start_conv
                    model_adapter.set_layer(model,
                                            ParameterType.DOWNSAMPLE_WEIGHTS,
                                            last_start_conv, tensor_index,
                                            layer_index, block_index)

                    keep_index, reset_index = start_index
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(last_keep_index.sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_WEIGHT:
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index = out_channels_keep_indexes[-1]
                    reset_index = reset_indexes[-1]

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_BIAS:
                    # NOTE(review): uses loop-carried keep_index from the
                    # preceding BN_WEIGHT iteration.
                    reset_indexes.append(reset_indexes[-1])
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.DOWNSAMPLE_BN_W:

                    # Downsample BN follows the downsample conv's keep set.
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index, reset_index = start_index

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.DOWNSAMPLE_BN_B:
                    keep_index, reset_index = start_index
                    reset_indexes.append(reset_index)
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.CNN_BIAS:
                    reset_indexes.append(reset_indexes[-1])
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1])

        finished_list = True
        # Map new (rebuilt) parameter ids back to the pre-prune ids so the
        # saved momentum buffers can be carried over.
        new_old_ids = OrderedDict()
        new_ids = OrderedDict()
        for k, v in model.named_parameters():
            if v.requires_grad:
                new_id = id(v)
                new_ids[k] = new_id
                new_old_ids[new_id] = current_ids[k]

        # Standard FRCNN practice: freeze the first base layers.
        for layer in range(10):
            for p in model.RCNN_base[layer].parameters():
                p.requires_grad = False

        # Rebuild SGD param groups (FRCNN convention: biases may get doubled
        # lr and separate weight decay, per cfg.TRAIN flags).
        params = []
        for key, value in dict(model.named_parameters()).items():
            if value.requires_grad:
                if 'bias' in key:
                    params += [{'params': [value], 'lr': lr * (cfg.TRAIN.DOUBLE_BIAS + 1), \
                                'weight_decay': cfg.TRAIN.BIAS_DECAY and cfg.TRAIN.WEIGHT_DECAY or 0}]
                else:
                    params += [{
                        'params': [value],
                        'lr': lr,
                        'weight_decay': cfg.TRAIN.WEIGHT_DECAY
                    }]

        optimizer = optim.SGD(params,
                              lr=optimizer.param_groups[0]["lr"],
                              momentum=optimizer.param_groups[0]["momentum"])
        n_new_state_dict = optimizer.state_dict()

        # Copy each param's old momentum into the fresh optimizer state
        # (state_dict keys params by position/id within param_groups).
        for i, k in enumerate(n_new_state_dict["param_groups"]):

            old_id = new_old_ids[k['params'][0]]
            old_momentum = o_state_dict["state"][old_id]
            n_new_state_dict["state"][k['params'][0]] = old_momentum
        in_place_load_state_dict(optimizer, n_new_state_dict)

        # Now prune/reset the momentum buffers themselves to match the new
        # tensor shapes. index_op_dict: param -> (type, out_keep, reset, in_keep).
        index_op_dict = OrderedDict()
        first_fc = False
        #type_list = [x for x in type_list if x is not None]
        for i in range(len(type_list)):
            if type_list[i] == ParameterType.FC_WEIGHTS and first_fc == False:
                index_op_dict[optimizer.param_groups[i]['params'][0]] = (
                    type_list[i], out_channels_keep_indexes[i - 1], None, None)
                first_fc = True
            elif type_list[i] == ParameterType.FC_BIAS:
                continue
            elif type_list[i] == ParameterType.DOWNSAMPLE_BN_B or type_list[i] == ParameterType.DOWNSAMPLE_BN_W or \
                    type_list[i] == ParameterType.BN_BIAS or type_list[i] == ParameterType.BN_WEIGHT:
                index_op_dict[optimizer.param_groups[i]['params'][0]] = (
                    type_list[i], out_channels_keep_indexes[i],
                    reset_indexes[i], None)
            elif type_list[i] is None:
                continue
            # NOTE(review): the last comparison below reads
            # `type_list == ParameterType.DOWNSAMPLE_BIAS` (whole list vs
            # enum, always False) — probably meant `type_list[i]`; kept as-is.
            elif type_list[i] == ParameterType.CNN_WEIGHTS or type_list[
                    i] == ParameterType.DOWNSAMPLE_WEIGHTS or type_list[
                        i] == ParameterType.CNN_BIAS or type_list == ParameterType.DOWNSAMPLE_BIAS:
                index_op_dict[optimizer.param_groups[i]['params'][0]] = (
                    type_list[i], out_channels_keep_indexes[i],
                    reset_indexes[i], in_channels_keep_indexes[i])

        j = 0
        for k, v in index_op_dict.items():

            if v[0] == ParameterType.CNN_WEIGHTS or v[
                    0] == ParameterType.DOWNSAMPLE_WEIGHTS:
                # Slice momentum on in-channels, zero/reinit reset filters,
                # then slice on kept out-channels — same order as the weights.
                if v[3] is not None and len(v[3]):
                    optimizer.state[k]["momentum_buffer"] = optimizer.state[k][
                        "momentum_buffer"][:, v[3], :, :]
                    if v[2] is not None:
                        optimizer.state[k]["momentum_buffer"][
                            v[2]] = initializer_fn(
                                optimizer.state[k]["momentum_buffer"][v[2]])
                optimizer.state[k]['momentum_buffer'] = optimizer.state[k][
                    'momentum_buffer'][v[1], :, :, :]

            elif v[0] == ParameterType.CNN_BIAS or v[0] == ParameterType.BN_WEIGHT or v[0] == ParameterType.BN_BIAS \
                    or v[0] == ParameterType.DOWNSAMPLE_BN_W or v[0] == ParameterType.DOWNSAMPLE_BN_B:
                if v[2] is not None:
                    optimizer.state[k]["momentum_buffer"][
                        v[2]] = initializer_fn(
                            optimizer.state[k]["momentum_buffer"][v[2]])
                optimizer.state[k]['momentum_buffer'] = optimizer.state[k][
                    'momentum_buffer'][v[1]]
            else:
                # FC weights: prune input features like the weight matrix.
                optimizer.state[k]['momentum_buffer'] = \
                    prune_fc_like(optimizer.state[k]['momentum_buffer'], v[1], original_out_channels)
            j += 1
        removed_filters_total += removed_filters_total_epoch
        parameters_hard_removed_total += parameters_hard_removed_per_epoch

        # Evaluate mAP on the pruned model and log pruning statistics.
        map = eval_frcnn(frcnn_extra, cuda, model, is_break)
        if logger is not None:
            logger.log_scalar(
                "pgp_target_frcnn_{}_after_target_val_acc".format(logger_id),
                map, epoch)
            logger.log_scalar(
                "pgp_target_frcnn_{}_number of filter removed".format(
                    logger_id),
                removed_filters_total + reset_filters_total_epoch, epoch)
            logger.log_scalar(
                "pgp_target_frcnn_{}_acc_number of filter removed".format(
                    logger_id), map,
                removed_filters_total + reset_filters_total_epoch)
            logger.log_scalar(
                "pgp_target_frcnn_{}_acc_number of parameters removed".format(
                    logger_id), map,
                parameters_hard_removed_total + parameters_reset_removed)
        torch.cuda.empty_cache()

    return loss_acc, model_architecture
# Esempio n. 4
# 0
def _compute_keep_index(original_out_channels, filtered_index):
    """Return the 1-D LongTensor of filter indexes to keep after pruning.

    Removes ``filtered_index`` from ``range(original_out_channels)``. Falls
    back to keeping every filter when removal would empty the layer, and
    re-wraps the result as a 1-D tensor when exactly one filter remains
    (``squeeze`` would otherwise yield a 0-d tensor).
    """
    original_index = torch.arange(0, original_out_channels)
    mask = torch.ones_like(original_index)
    mask[filtered_index] = 0
    keep_index = original_index[mask.nonzero()].squeeze()
    remaining = original_index.shape[0] - filtered_index.shape[0]
    if remaining == 0:
        keep_index = torch.arange(0, original_out_channels).long()
    elif remaining == 1:
        keep_index = torch.LongTensor([keep_index])
    return keep_index


def iterative_pruning(epochs_fn, target_prune, prune_each_steps, test_loader,
                      criterion_func, **kwargs):
    """Iteratively prune the weakest conv filters, finetuning between steps.

    Each iteration ranks filters with ``criterion_func``, removes up to
    ``prune_each_steps`` of the globally weakest ones, rebuilds the affected
    conv/BN/FC layers through ``model_adapter``, transfers the SGD momentum
    buffers onto the shrunken parameters, finetunes for ``epochs_fn - 1``
    epochs, then evaluates and logs. Stops once the removed-filter count
    exceeds ``target_prune`` times the original filter count.

    Args:
        epochs_fn: Upper bound of the per-iteration finetune loop
            (``range(1, epochs_fn)``, i.e. ``epochs_fn - 1`` passes).
        target_prune: Fraction of the original filters to remove overall.
        prune_each_steps: Max filters removed per iteration.
        test_loader: Loader used by ``eval`` for accuracy.
        criterion_func: Called with ``**kwargs``; returns
            ``(prune_index_dict, values_indexes)`` ranking filters per layer.
        **kwargs: Must supply ``model``, ``cuda``, ``train_loader``,
            ``train_ratio``, ``model_adapter``, ``is_break``, ``optimizer``,
            ``logger``; optional ``logger_id``.

    Returns:
        ``(loss_acc, model_architecture)`` — list of
        ``(avg_train_loss, accuracy)`` pairs, and a dict mapping parameter
        names to the kept-filter count recorded at each iteration.
    """
    model = kwargs["model"]
    cuda = kwargs["cuda"]
    train_loader = kwargs["train_loader"]
    train_ratio = kwargs["train_ratio"]
    model_adapter = kwargs["model_adapter"]
    is_break = kwargs["is_break"]
    loss_acc = []
    type_list = []
    finished_list = False
    model_architecture = {}
    optimizer = kwargs["optimizer"]

    # Milestones (30/50/70/90% of the original filter count) at which FLOPs
    # and parameter counts are profiled and logged.
    key_pts = np.asarray([0.3, 0.5, 0.7, 0.9])
    ki = 0

    logger = kwargs["logger"]
    logger_id = kwargs.get("logger_id", "")

    original_num_filers = get_original_num_of_filters(model, model_adapter)
    target_num_prune = int(original_num_filers * target_prune /
                           prune_each_steps)
    key_pts = key_pts * original_num_filers
    removed_filter_total = 0
    removed_parameters_total = 0

    for num_of_removed in range(1, target_num_prune + 1):

        prune_index_dict, values_indexes = criterion_func(**kwargs)

        out_channels_keep_indexes = []
        in_channels_keep_indexes = []

        original_out_channels = 0
        first_fc = False
        current_ids = {}
        start_index = None
        last_start_conv = None
        last_keep_index = None

        # Build one globally sorted list of the weakest filters across all
        # prunable layers, capped at prune_each_steps entries.
        weakest_list = []
        for k, v in values_indexes.items():
            param_type, tensor_index, layer_index, block_index = model_adapter.get_param_type_and_layer_index(
                k)
            # Skip scalar entries, single-filter layers, and the second conv
            # of a residual block (its out-channels are tied to the shortcut).
            if len(v.shape) == 0 or v.shape[0] == 1 or (tensor_index == 2
                                                        and layer_index != -1):
                continue

            limit = min(prune_each_steps, v.shape[0])

            for i in range(v[:limit].shape[0]):
                if param_type == ParameterType.DOWNSAMPLE_WEIGHTS:
                    # Downsample filters are pruned via the conv2 they feed.
                    _, conv2_name = model_adapter.get_conv2_from_downsample(
                        model, layer_index, block_index)
                    weakest_list = insert_sort_list(weakest_list, v[i],
                                                    conv2_name,
                                                    prune_index_dict[k][i],
                                                    True, prune_each_steps)
                else:
                    weakest_list = insert_sort_list(weakest_list, v[i], k,
                                                    prune_index_dict[k][i],
                                                    False, prune_each_steps)

        prune_name_dict = finalize_weakest_list(weakest_list, model,
                                                model_adapter)

        for name, parameters in model.named_parameters():

            current_ids[name] = id(parameters)
            param_type, tensor_index, layer_index, block_index = model_adapter.get_param_type_and_layer_index(
                name)
            if not finished_list:
                type_list.append(param_type)

            if layer_index == -1:
                # Handling CNN and BN before Resnet
                if param_type == ParameterType.CNN_WEIGHTS:

                    conv_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)
                    original_out_channels = parameters.shape[
                        0]  # conv_tensor.out_channels

                    reset_index = None
                    if prune_name_dict is not None and name in prune_name_dict:
                        _, filtered_index, _ = prune_name_dict[name]
                        keep_index = _compute_keep_index(
                            original_out_channels, filtered_index)
                    else:
                        keep_index = torch.arange(
                            0, original_out_channels).long()

                    new_conv_tensor = create_conv_tensor(
                        conv_tensor, out_channels_keep_indexes, None,
                        keep_index, reset_index).to(cuda)
                    model_adapter.set_layer(model, param_type, new_conv_tensor,
                                            tensor_index, layer_index,
                                            block_index)

                    if name not in model_architecture:
                        model_architecture[name] = []
                    model_architecture[name].append(keep_index.shape[0])
                    start_index = (keep_index.sort()[0], reset_index)

                    in_c = conv_tensor.in_channels
                    if len(out_channels_keep_indexes) != 0:
                        in_c = out_channels_keep_indexes[-1].shape[0]

                    # parameters is the conv weight (out, in, kH, kW):
                    # shape[2:].numel() is the kernel area. Was
                    # `parameters[2:].numel()`, which sliced the tensor along
                    # dim 0 and grossly over-counted removed parameters
                    # (prune_once uses the correct form).
                    removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                in_c * parameters.shape[2:].numel()
                    removed_filter_total += original_out_channels - keep_index.shape[
                        0]

                    if out_channels_keep_indexes is not None and len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    else:
                        in_channels_keep_indexes.append(None)
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.CNN_BIAS:
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1])

                elif param_type == ParameterType.BN_WEIGHT:
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    # BN keeps exactly the channels the preceding conv kept.
                    keep_index = out_channels_keep_indexes[-1]

                    n_bn = create_new_bn(bn_tensor, keep_index, None)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)
                    del bn_tensor
                    torch.cuda.empty_cache()

                    if out_channels_keep_indexes is not None or len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_BIAS:

                    # NOTE(review): relies on keep_index still holding the
                    # value set by the matching BN_WEIGHT iteration.
                    if out_channels_keep_indexes is not None or len(
                            out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(
                            out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.FC_WEIGHTS and first_fc == False:
                    # First FC after the conv stack: drop the input features
                    # that correspond to pruned channels.
                    fc_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)
                    new_fc_weight = prune_fc_like(
                        fc_tensor.weight.data, out_channels_keep_indexes[-1],
                        original_out_channels)

                    new_fc_bias = None
                    if fc_tensor.bias is not None:
                        new_fc_bias = fc_tensor.bias.data
                    new_fc_tensor = nn.Linear(new_fc_weight.shape[1],
                                              new_fc_weight.shape[0],
                                              bias=new_fc_bias
                                              is not None).to(cuda)
                    new_fc_tensor.weight.data = new_fc_weight
                    if fc_tensor.bias is not None:
                        new_fc_tensor.bias.data = new_fc_bias
                    model_adapter.set_layer(model, param_type, new_fc_tensor,
                                            tensor_index, layer_index,
                                            block_index)

                    del fc_tensor
                    torch.cuda.empty_cache()

                    first_fc = True
                    finished_list = True

            else:

                if param_type == ParameterType.CNN_WEIGHTS:

                    if tensor_index == 1:
                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)
                        original_out_channels = parameters.shape[
                            0]  # conv_tensor.out_channels

                        reset_index = None
                        if prune_name_dict is not None and name in prune_name_dict:
                            _, filtered_index, _ = prune_name_dict[name]
                            keep_index = _compute_keep_index(
                                original_out_channels, filtered_index)
                        else:
                            keep_index = torch.arange(
                                0, original_out_channels).long()

                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes, None,
                            keep_index, reset_index).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        if name not in model_architecture:
                            model_architecture[name] = []
                        model_architecture[name].append(keep_index.shape[0])

                        in_c = conv_tensor.in_channels
                        if len(out_channels_keep_indexes) != 0:
                            in_c = out_channels_keep_indexes[-1].shape[0]
                        # Kernel area via shape[2:] (see comment above).
                        removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                    in_c * parameters.shape[2:].numel()
                        removed_filter_total += original_out_channels - keep_index.shape[
                            0]

                        if out_channels_keep_indexes is not None and len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])
                        out_channels_keep_indexes.append(keep_index.sort()[0])

                    elif tensor_index == 2:

                        # Second conv of a residual block: its out-channels
                        # must match the shortcut (downsample) branch.
                        downsample_cnn, d_name = model_adapter.get_downsample(
                            model, layer_index, block_index)

                        if downsample_cnn is not None:
                            original_out_channels = parameters.shape[
                                0]  # conv_tensor.out_channels

                            last_keep_index, _ = start_index

                            reset_index = None
                            if prune_name_dict is not None and name in prune_name_dict and prune_name_dict[
                                    name][2]:
                                _, filtered_index, _ = prune_name_dict[name]
                                keep_index = _compute_keep_index(
                                    original_out_channels, filtered_index)
                            else:
                                keep_index = torch.arange(
                                    0, original_out_channels).long()

                            last_start_conv = create_conv_tensor(
                                downsample_cnn, [last_keep_index], None,
                                keep_index, reset_index).to(cuda)
                            last_start_conv = [
                                last_start_conv, 0, layer_index, block_index
                            ]

                            if d_name not in model_architecture:
                                model_architecture[d_name] = []
                            model_architecture[d_name].append(
                                keep_index.shape[0])
                            start_index = (keep_index.sort()[0], reset_index)
                            # Kernel area via shape[2:] (see comment above).
                            removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                        last_keep_index.shape[0] * parameters.shape[2:].numel()
                            removed_filter_total += original_out_channels - keep_index.shape[
                                0]

                        original_out_channels = parameters.shape[0]
                        conv_tensor = model_adapter.get_layer(
                            model, param_type, tensor_index, layer_index,
                            block_index)
                        keep_index, reset_index = start_index

                        new_conv_tensor = create_conv_tensor(
                            conv_tensor, out_channels_keep_indexes, None,
                            keep_index, reset_index).to(cuda)
                        model_adapter.set_layer(model, param_type,
                                                new_conv_tensor, tensor_index,
                                                layer_index, block_index)

                        if out_channels_keep_indexes is not None and len(
                                out_channels_keep_indexes) != 0:
                            in_channels_keep_indexes.append(
                                out_channels_keep_indexes[-1].sort()[0])

                        # Kernel area via shape[2:] (see comment above).
                        removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                    out_channels_keep_indexes[-1].shape[0] * parameters.shape[2:].numel()
                        removed_filter_total += original_out_channels - keep_index.shape[
                            0]

                        out_channels_keep_indexes.append(keep_index.sort()[0])
                        if name not in model_architecture:
                            model_architecture[name] = []
                        model_architecture[name].append(keep_index.shape[0])

                elif param_type == ParameterType.DOWNSAMPLE_WEIGHTS:

                    # Install the shortcut conv prepared in the tensor_index==2
                    # branch above.
                    last_start_conv, tensor_index, layer_index, block_index = last_start_conv
                    model_adapter.set_layer(model,
                                            ParameterType.DOWNSAMPLE_WEIGHTS,
                                            last_start_conv, tensor_index,
                                            layer_index, block_index)

                    keep_index, reset_index = start_index

                    in_channels_keep_indexes.append(last_keep_index.sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_WEIGHT:
                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index = out_channels_keep_indexes[-1]

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)

                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.BN_BIAS:
                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.DOWNSAMPLE_BN_W:

                    bn_tensor = model_adapter.get_layer(
                        model, param_type, tensor_index, layer_index,
                        block_index)

                    keep_index, reset_index = start_index

                    n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                    model_adapter.set_layer(model, param_type, n_bn,
                                            tensor_index, layer_index,
                                            block_index)

                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.DOWNSAMPLE_BN_B:
                    keep_index, reset_index = start_index

                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif param_type == ParameterType.CNN_BIAS:

                    in_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(
                        out_channels_keep_indexes[-1])

        if num_of_removed > 1:
            # Rebuild the optimizer around the new (smaller) parameters and
            # carry the old momentum buffers over, sliced to the kept indexes.
            new_old_ids = {}
            new_ids = {}
            for k, v in model.named_parameters():
                new_id = id(v)
                new_ids[k] = new_id
                new_old_ids[new_id] = current_ids[k]

            o_state_dict = optimizer.state_dict()
            optimizer = optim.SGD(
                model.parameters(),
                lr=optimizer.param_groups[0]["lr"],
                momentum=optimizer.param_groups[0]["momentum"])
            n_new_state_dict = optimizer.state_dict()

            for k in n_new_state_dict["param_groups"][0]["params"]:
                old_id = new_old_ids[k]
                old_momentum = o_state_dict["state"][old_id]
                n_new_state_dict["state"][k] = old_momentum
            in_place_load_state_dict(optimizer, n_new_state_dict)

            # (param -> (type, out-keep, reset, in-keep)) slicing plan for the
            # momentum buffers.
            index_op_dict = {}
            first_fc = False
            for i in range(len(type_list)):
                if type_list[
                        i] == ParameterType.FC_WEIGHTS and first_fc == False:
                    index_op_dict[optimizer.param_groups[0]['params'][i]] = (
                        type_list[i], out_channels_keep_indexes[i - 1], None,
                        None)
                    first_fc = True
                elif type_list[i] == ParameterType.FC_BIAS:
                    continue
                elif type_list[i] == ParameterType.DOWNSAMPLE_BN_B or type_list[i] == ParameterType.DOWNSAMPLE_BN_W or \
                        type_list[i] == ParameterType.BN_BIAS or type_list[i] == ParameterType.BN_WEIGHT:
                    index_op_dict[optimizer.param_groups[0]['params'][i]] = (
                        type_list[i], out_channels_keep_indexes[i], None, None)
                else:
                    index_op_dict[optimizer.param_groups[0]['params'][i]] = (
                        type_list[i], out_channels_keep_indexes[i], None,
                        in_channels_keep_indexes[i])

            for k, v in index_op_dict.items():
                if v[0] == ParameterType.CNN_WEIGHTS or v[
                        0] == ParameterType.DOWNSAMPLE_WEIGHTS:
                    if v[3] is not None:
                        optimizer.state[k][
                            "momentum_buffer"] = optimizer.state[k][
                                "momentum_buffer"][:, v[3], :, :]

                    optimizer.state[k]['momentum_buffer'] = optimizer.state[k][
                        'momentum_buffer'][v[1], :, :, :]

                elif v[0] == ParameterType.CNN_BIAS or v[0] == ParameterType.BN_WEIGHT or v[0] == ParameterType.BN_BIAS \
                        or v[0] == ParameterType.DOWNSAMPLE_BN_W or v[0] == ParameterType.DOWNSAMPLE_BN_B:

                    optimizer.state[k]['momentum_buffer'] = optimizer.state[k][
                        'momentum_buffer'][v[1]]
                else:
                    optimizer.state[k]['momentum_buffer'] = \
                        prune_fc_like(optimizer.state[k]['momentum_buffer'], v[1], original_out_channels)

        if num_of_removed == 1:
            # First iteration: momentum buffers match no previous shapes, so
            # start the optimizer fresh.
            optimizer = optim.SGD(
                model.parameters(),
                lr=optimizer.param_groups[0]['lr'],
                momentum=optimizer.param_groups[0]['momentum'])
            torch.cuda.empty_cache()
        for epoch in range(1, epochs_fn):
            model.train()
            optimizer.zero_grad()
            total_loss = train(model, optimizer, cuda, train_loader, is_break)
        acc = eval(model, cuda, test_loader, is_break)

        if logger is not None:
            logger.log_scalar(
                "iterative_pruning_{}_after_target_val_acc".format(logger_id),
                acc, num_of_removed)
            # Was a literal "{}" tag: the scalar name was never formatted with
            # logger_id, unlike every other log call here.
            logger.log_scalar(
                "iterative_pruning_{}_number of filter removed".format(
                    logger_id),
                removed_filter_total, epoch)
            logger.log_scalar(
                "iterative_pruning_{}_acc_number of filter removed".format(
                    logger_id), acc, removed_filter_total)
            logger.log_scalar(
                "iterative_pruning_{}_acc_parameters_removed".format(
                    logger_id), acc, removed_parameters_total)

            if ki < key_pts.shape[0] and removed_filter_total >= key_pts[ki]:
                flops, params = profile(
                    model,
                    input_size=train_loader.dataset[0][0].unsqueeze(0).shape)
                logger.log_scalar(
                    "iterative_pruning_{}_flops_counts".format(logger_id),
                    flops, key_pts[ki])
                logger.log_scalar(
                    "iterative_pruning_{}_params_counts".format(logger_id),
                    params, key_pts[ki])
                ki += 1

            # NOTE(review): total_loss is only bound when epochs_fn > 1, and
            # loss_acc is only recorded when a logger is supplied — confirm
            # both are intended.
            loss_acc.append((total_loss / len(train_loader), acc))

        if removed_filter_total > target_prune * original_num_filers:
            print("{}/{}".format(removed_filter_total, original_num_filers))
            break

    flops, params = profile(
        model, input_size=train_loader.dataset[0][0].unsqueeze(0).shape)
    logger.log_scalar("iterative_pruning_{}_flops_counts".format(logger_id),
                      flops, key_pts[-1])
    logger.log_scalar("iterative_pruning_{}_params_counts".format(logger_id),
                      params, key_pts[-1])
    return loss_acc, model_architecture
# --- Esempio n. 5 (0) --- scraped-example separator, commented out so the file parses
def prune_once(epochs_fn, optimizer, target_prune_remains, test_loader,
               criterion_func, **kwargs):

    model = kwargs["model"]
    cuda = kwargs["cuda"]
    train_loader = kwargs["train_loader"]
    train_ratio = kwargs["train_ratio"]
    model_adapter = kwargs["model_adapter"]
    is_break = kwargs["is_break"]
    loss_acc = []
    type_list = []
    finished_list = False
    model_architecture = {}

    logger = kwargs["logger"]

    if not "logger_id" in kwargs:
        logger_id = ""
    else:
        logger_id = kwargs["logger_id"]

    before_epochs = 0
    prune_index_dict, _ = criterion_func(**kwargs)
    out_channels_keep_indexes = []
    in_channels_keep_indexes = []

    original_out_channels = 0
    first_fc = False
    current_ids = {}
    start_index = None
    last_start_conv = None
    last_keep_index = None
    removed_filters_total_epoch = 0
    removed_parameters_total = 0


    for name, parameters in model.named_parameters():

        current_ids[name] = id(parameters)
        param_type, tensor_index, layer_index, block_index = model_adapter.get_param_type_and_layer_index(name)
        if not finished_list:
            type_list.append(param_type)
        if layer_index == -1:
            # Handling CNN and BN before Resnet
            if param_type == ParameterType.CNN_WEIGHTS:
                sorted_filters_index = prune_index_dict[name]
                conv_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)
                original_out_channels = parameters.shape[0]  # conv_tensor.out_channels

                keep_index, reset_index = get_prune_index_target(original_out_channels, target_prune_remains[name],
                                                                 sorted_filters_index)

                new_conv_tensor = create_conv_tensor(conv_tensor, out_channels_keep_indexes, None,
                                                     keep_index, reset_index).to(cuda)
                model_adapter.set_layer(model, param_type, new_conv_tensor, tensor_index, layer_index, block_index)

                if name not in model_architecture:
                    model_architecture[name] = []
                model_architecture[name].append(keep_index.shape[0])
                removed_filters_total_epoch += original_out_channels - keep_index.shape[0]

                in_c = 3
                if len(out_channels_keep_indexes) != 0:
                    in_c = out_channels_keep_indexes[-1].shape[0]

                removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                            in_c * parameters.shape[2:].numel()
                start_index = (keep_index.sort()[0], reset_index)

                if out_channels_keep_indexes is not None and len(out_channels_keep_indexes) != 0:
                    in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                else:
                    in_channels_keep_indexes.append(None)
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.CNN_BIAS:
                in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(out_channels_keep_indexes[-1])

            elif param_type == ParameterType.BN_WEIGHT:
                bn_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)

                keep_index = out_channels_keep_indexes[-1]


                n_bn = create_new_bn(bn_tensor, keep_index, None)
                model_adapter.set_layer(model, param_type, n_bn, tensor_index, layer_index, block_index)

                if out_channels_keep_indexes is not None or len(out_channels_keep_indexes) != 0:
                    in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.BN_BIAS:

                if out_channels_keep_indexes is not None or len(out_channels_keep_indexes) != 0:
                    in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.FC_WEIGHTS and first_fc == False:
                fc_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)
                new_fc_weight = prune_fc_like(fc_tensor.weight.data, out_channels_keep_indexes[-1],
                                              original_out_channels)

                new_fc_bias = None
                if fc_tensor.bias is not None:
                    new_fc_bias = fc_tensor.bias.data
                # Rebuild the first fully-connected layer with the pruned input
                # width. new_fc_weight / new_fc_bias / fc_tensor come from the
                # (unseen) code directly above this chunk — presumably the FC
                # weight matrix was sliced to drop removed conv channels;
                # verify upstream.
                new_fc_tensor = nn.Linear(new_fc_weight.shape[1], new_fc_weight.shape[0],
                                          bias=new_fc_bias is not None).to(cuda)
                new_fc_tensor.weight.data = new_fc_weight
                if fc_tensor.bias is not None:
                    new_fc_tensor.bias.data = new_fc_bias
                model_adapter.set_layer(model, param_type, new_fc_tensor, tensor_index, layer_index, block_index)
                first_fc = True
                # Stop recording architecture entries once the first FC layer
                # has been rebuilt.
                finished_list = True

        else:

            if param_type == ParameterType.CNN_WEIGHTS:

                if tensor_index == 1:
                    # First conv of the block: prune its output filters down to
                    # the per-layer target count in target_prune_remains.
                    sorted_filters_index = prune_index_dict[name]
                    conv_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)
                    original_out_channels = parameters.shape[0]  # conv_tensor.out_channels
                    keep_index, reset_index = get_prune_index_target(original_out_channels, target_prune_remains[name],
                                                                     sorted_filters_index)


                    new_conv_tensor = create_conv_tensor(conv_tensor, out_channels_keep_indexes, None,
                                                         keep_index, reset_index).to(cuda)
                    model_adapter.set_layer(model, param_type, new_conv_tensor, tensor_index, layer_index,
                                            block_index)

                    # Record the surviving filter count for this layer.
                    if name not in model_architecture:
                        model_architecture[name] = []
                    model_architecture[name].append(keep_index.shape[0])
                    removed_filters_total_epoch += original_out_channels - keep_index.shape[0]

                    # Effective input channels: the previous layer's surviving
                    # channels if any layer was already pruned, else the
                    # original in_channels.
                    in_c = conv_tensor.in_channels
                    if len(out_channels_keep_indexes) != 0:
                        in_c = out_channels_keep_indexes[-1].shape[0]

                    # parameters.shape[2:] is the kernel spatial size, so this
                    # counts removed weights = removed_filters * in_c * kH * kW.
                    removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                in_c * parameters.shape[2:].numel()

                    # NOTE(review): this guard looks wrong — with `or`, a None
                    # list still reaches len(None) (TypeError) and an empty
                    # list short-circuits True and hits [-1] below
                    # (IndexError). Probably intended
                    # `is not None and len(...) != 0`; confirm against the
                    # unseen initialization of out_channels_keep_indexes.
                    if out_channels_keep_indexes is not None or len(out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                    out_channels_keep_indexes.append(keep_index.sort()[0])

                elif tensor_index == 2:
                    # Second conv of a residual block: its output must match
                    # the block's shortcut, so the keep set comes from
                    # start_index (set in the unseen code above) and, when a
                    # downsample conv exists, is re-derived from the
                    # downsample's prune target.
                    downsample_cnn, d_name = model_adapter.get_downsample(model, layer_index, block_index)
                    if downsample_cnn is not None:

                        sorted_filters_index = prune_index_dict[d_name]
                        original_out_channels = parameters.shape[0]  # conv_tensor.out_channels
                        last_keep_index, _ = start_index

                        keep_index, reset_index = get_prune_index_target(original_out_channels, target_prune_remains[d_name],
                                                                         sorted_filters_index)

                        # Build the replacement downsample conv now, but defer
                        # installing it until the DOWNSAMPLE_WEIGHTS parameter
                        # is visited (see branch below).
                        last_start_conv = create_conv_tensor(downsample_cnn, [last_keep_index], None,
                                                             keep_index, reset_index).to(cuda)
                        last_start_conv = [last_start_conv, 0, layer_index, block_index]

                        if d_name not in model_architecture:
                            model_architecture[d_name] = []
                        model_architecture[d_name].append(keep_index.shape[0])
                        removed_filters_total_epoch += original_out_channels - keep_index.shape[0]
                        removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                    last_keep_index.shape[0] * parameters.shape[2:].numel()
                        start_index = (keep_index.sort()[0], reset_index)

                    conv_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)
                    keep_index, reset_index = start_index

                    new_conv_tensor = create_conv_tensor(conv_tensor, out_channels_keep_indexes, None,
                                                         keep_index, reset_index).to(cuda)
                    model_adapter.set_layer(model, param_type, new_conv_tensor, tensor_index, layer_index,
                                            block_index)


                    # NOTE(review): same suspect `or` guard as in the
                    # tensor_index == 1 branch above — likely meant `and`.
                    if out_channels_keep_indexes is not None or len(out_channels_keep_indexes) != 0:
                        in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])

                    # NOTE(review): if downsample_cnn is None,
                    # original_out_channels still holds the value from a
                    # previous parameter iteration; and when a downsample
                    # exists these two lines appear to count the same filter
                    # removals a second time (already added in the branch
                    # above). Confirm whether this double counting is intended.
                    removed_filters_total_epoch += original_out_channels - keep_index.shape[0]
                    removed_parameters_total += (original_out_channels - keep_index.shape[0]) * \
                                                out_channels_keep_indexes[-1].shape[0] * parameters.shape[2:].numel()
                    out_channels_keep_indexes.append(keep_index.sort()[0])
                    if name not in model_architecture:
                        model_architecture[name] = []
                    model_architecture[name].append(keep_index.shape[0])


            elif param_type == ParameterType.DOWNSAMPLE_WEIGHTS:

                # Install the replacement downsample conv that was prepared in
                # the tensor_index == 2 branch (the unpack shadows the list
                # with the conv module itself).
                last_start_conv, tensor_index, layer_index, block_index = last_start_conv
                model_adapter.set_layer(model, ParameterType.DOWNSAMPLE_WEIGHTS, last_start_conv, tensor_index,
                                        layer_index,
                                        block_index)

                keep_index, reset_index = start_index

                # NOTE(review): last_keep_index is only assigned inside the
                # downsample branch above — this relies on named_parameters
                # visiting conv2 before its downsample; confirm the iteration
                # order guarantees it.
                in_channels_keep_indexes.append(last_keep_index.sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.BN_WEIGHT:
                # BatchNorm follows a conv: shrink it to the conv's surviving
                # channels.
                bn_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)

                keep_index = out_channels_keep_indexes[-1]


                # NOTE(review): reset_index here is whatever the last CNN
                # branch left behind — presumably matching this BN's conv;
                # verify the parameter ordering.
                n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                model_adapter.set_layer(model, param_type, n_bn, tensor_index, layer_index, block_index)

                in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.BN_BIAS:
                # Bias tensor of the BN handled just above; keep_index is
                # carried over from that branch (relies on weight preceding
                # bias in named_parameters).
                in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.DOWNSAMPLE_BN_W:

                # Downsample's BatchNorm: shrink to the shortcut's keep set.
                bn_tensor = model_adapter.get_layer(model, param_type, tensor_index, layer_index, block_index)

                keep_index, reset_index = start_index

                n_bn = create_new_bn(bn_tensor, keep_index, reset_index)
                model_adapter.set_layer(model, param_type, n_bn, tensor_index, layer_index, block_index)

                in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])

            elif param_type == ParameterType.DOWNSAMPLE_BN_B:
                # Downsample BN bias: bookkeeping only, layer already rebuilt.
                keep_index, reset_index = start_index

                in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(keep_index.sort()[0])


            elif param_type == ParameterType.CNN_BIAS:

                # Conv bias: channel count unchanged from the conv's weights.
                in_channels_keep_indexes.append(out_channels_keep_indexes[-1].sort()[0])
                out_channels_keep_indexes.append(out_channels_keep_indexes[-1])
    # Single-shot target pruning: the epoch total is the overall total.
    removed_filters_total = removed_filters_total_epoch



    # Re-create the optimizer so it tracks the new (pruned) parameter tensors.
    # NOTE(review): only lr and momentum are carried over — any weight_decay /
    # nesterov settings on the original optimizer are silently dropped;
    # confirm that is intended.
    optimizer = optim.SGD(model.parameters(), lr=optimizer.param_groups[0]['lr'], momentum=optimizer.param_groups[0]['momentum'])
    for epoch in range(before_epochs + 1, epochs_fn + 1):
        model.train()

        # One epoch step gradient for target
        optimizer.zero_grad()
        # NOTE(review): time.clock() was removed in Python 3.8 — use
        # time.perf_counter() instead; `start` is also never read afterwards.
        start = time.clock()
        total_loss = train(model, optimizer, cuda, train_loader,is_break)
        # `eval` shadows the builtin — presumably a project-level evaluation
        # helper defined/imported elsewhere in this file.
        acc = eval(model, cuda, test_loader,is_break)

        if logger is not None:
            logger.log_scalar("prune_once_{}_after_target_val_acc".format(logger_id), acc, epoch)
            # NOTE(review): missing .format(logger_id) on this call — the
            # literal "{}" ends up in the metric name; confirm and fix.
            logger.log_scalar("prune_once_{}_number of filter removed", removed_filters_total, epoch)
            logger.log_scalar("prune_once_{}_acc_number of filter removed".format(logger_id), acc, removed_filters_total)
            logger.log_scalar("prune_once_{}_acc_number of parameters removed".format(logger_id), acc,
                              removed_parameters_total)

        # Average training loss per batch, paired with validation accuracy.
        loss_acc.append((total_loss / len(train_loader), acc))

    return loss_acc, model_architecture