示例#1
0
def test_calculate_flops(op_type, input_shape, output_shape, weight_shape,
                         kernel_shape, bias_shape, flops):
    """Check that ``calculate_flops`` returns ``flops`` for the given op.

    The arguments are presumably supplied by a parametrize decorator that is
    not visible here; every shape is forwarded to ``calculate_flops`` by
    keyword and the result is compared against ``flops``.
    """
    shape_kwargs = {
        "input_shape": input_shape,
        "output_shape": output_shape,
        "weight_shape": weight_shape,
        "kernel_shape": kernel_shape,
        "bias_shape": bias_shape,
    }
    computed = calculate_flops(op_type, **shape_kwargs)
    assert flops == computed
示例#2
0
def test_calculate_flops_negatives(op_type, input_shape, output_shape,
                                   weight_shape, kernel_shape, bias_shape):
    """Check that ``calculate_flops`` returns None for the given inputs.

    The arguments are presumably supplied by a parametrize decorator that is
    not visible here; every shape is forwarded to ``calculate_flops`` by
    keyword.
    """
    shape_kwargs = {
        "input_shape": input_shape,
        "output_shape": output_shape,
        "weight_shape": weight_shape,
        "kernel_shape": kernel_shape,
        "bias_shape": bias_shape,
    }
    result = calculate_flops(op_type, **shape_kwargs)
    assert result is None
示例#3
0
    def __init__(
        self,
        model: Union[ModelProto, None],
        node: Union[Any, None],
        node_shape: Union[NodeShape, None] = None,
        **kwargs,
    ):
        """
        Populate the node's analysis fields, either by inspecting a
        model/node pair or by loading precomputed values from ``kwargs``.

        :param model: the model that contains ``node``; pass None together
            with ``node=None`` to load every field from ``kwargs`` instead
        :param node: the node to analyze; pass None together with
            ``model=None`` to load every field from ``kwargs`` instead
        :param node_shape: optional object exposing ``input_shapes`` and
            ``output_shapes``; when None both shapes are recorded as None
        :param kwargs: only read when both ``model`` and ``node`` are None.
            Required keys: ``id``, ``op_type``, ``input_names``,
            ``output_names``, ``input_shapes``, ``output_shapes``, ``params``,
            ``prunable``, ``prunable_params_zeroed``, ``weight_name``,
            ``weight_shape``, ``bias_name``, ``bias_shape``, ``attributes``,
            ``flops``. Optional keys (default None): ``prunable_params``,
            ``prunable_equation_sensitivity``
        :raises ValueError: if exactly one of ``model`` and ``node`` is None
        :raises KeyError: if a required key is missing from ``kwargs``
        """
        if model is None and node is None:
            # Precomputed path: nothing is analyzed, fields come from kwargs.
            self._id = kwargs["id"]
            self._op_type = kwargs["op_type"]
            self._input_names = kwargs["input_names"]
            self._output_names = kwargs["output_names"]
            self._input_shapes = kwargs["input_shapes"]
            self._output_shapes = kwargs["output_shapes"]
            self._params = kwargs["params"]
            self._prunable = kwargs["prunable"]
            # The analysis path below always sets _prunable_params, so set it
            # here too; otherwise instances built from kwargs lack the
            # attribute. The key is optional so existing callers that do not
            # pass it keep working (None then means "not provided").
            self._prunable_params = kwargs.get("prunable_params")
            self._prunable_params_zeroed = kwargs["prunable_params_zeroed"]
            self._weight_name = kwargs["weight_name"]
            self._weight_shape = kwargs["weight_shape"]
            self._bias_name = kwargs["bias_name"]
            self._bias_shape = kwargs["bias_shape"]
            self._attributes = kwargs["attributes"]
            self._flops = kwargs["flops"]
            self._prunable_equation_sensitivity = kwargs.get(
                "prunable_equation_sensitivity"
            )

            return

        if model is None or node is None:
            raise ValueError("both model and node must not be None")

        self._id = extract_node_id(node)
        self._op_type = node.op_type
        self._input_names = get_node_inputs(model, node)
        self._output_names = get_node_outputs(model, node)

        if node_shape is None:
            self._input_shapes = None
            self._output_shapes = None
        else:
            self._input_shapes = node_shape.input_shapes
            self._output_shapes = node_shape.output_shapes

        # Defaults for a node without prunable parameters; overwritten below
        # when the node is prunable.
        self._params = 0
        self._prunable = is_prunable_node(model, node)
        self._prunable_params = 0
        self._prunable_params_zeroed = 0
        self._weight_name = None
        self._weight_shape = None
        self._bias_name = None
        self._bias_shape = None
        self._attributes = get_node_attributes(node)

        if self._prunable:
            weight, bias = get_node_params(model, node)
            weight_size = weight.val.size
            self._params += weight_size
            self._prunable_params += weight_size
            # Zeroed count = total weight entries minus the non-zero ones.
            self._prunable_params_zeroed += weight_size - numpy.count_nonzero(
                weight.val
            )
            self._weight_name = weight.name
            self._weight_shape = list(weight.val.shape)

            if bias is not None:
                # The bias adds to the total parameter count only; it is not
                # added to the prunable parameter count.
                self._bias_name = bias.name
                self._params += bias.val.size
                self._bias_shape = list(bias.val.shape)

        kernel_shape = get_kernel_shape(self._attributes)
        self._flops = calculate_flops(
            self._op_type,
            input_shape=self._input_shapes,
            output_shape=self._output_shapes,
            weight_shape=self._weight_shape,
            kernel_shape=kernel_shape,
            bias_shape=self._bias_shape,
            attributes=self._attributes,
        )

        # The sensitivity estimate is only computed for prunable nodes.
        self._prunable_equation_sensitivity = (
            pruning_loss_sens_approx(
                self._input_shapes,
                self._output_shapes,
                self._params,
                apply_shape_change_mult=True,
            )
            if self._prunable
            else None
        )