Ejemplo n.º 1
0
    def expected_flops(self, input_alpha, in_height, in_width):
        """Return the expected FLOPs of this block and its output spatial size.

        Side effects: on first use, ``self.alpha2`` is bound to ``input_alpha``
        when ``stride == 1`` and ``inplanes == outplanes`` (presumably the
        shape-preserving / residual case), and ``self.alpha1`` is bound to
        ``input_alpha`` when ``t == 1``. An already-set alpha is never replaced.

        Args:
            input_alpha: alpha object of the preceding layer; its
                ``expected_channel()`` is used as this block's expected input
                channel count.
            in_height: input feature-map height.
            in_width: input feature-map width.

        Returns:
            ``(e_flops, out_height, out_width)``: the summed expected FLOPs of
            conv1 (skipped when ``t == 1``), conv2 and conv3, and the spatial
            size produced by conv3.
        """
        same_shape = self.stride == 1 and self.inplanes == self.outplanes
        if same_shape and self.alpha2 is None:
            # Output channels must follow the input channels in this case.
            self.alpha2 = input_alpha

        uses_conv1 = self.t != 1
        if not uses_conv1 and self.alpha1 is None:
            self.alpha1 = input_alpha

        in_channels = input_alpha.expected_channel()
        if uses_conv1:
            mid_channels = self.alpha1.expected_channel()
            conv1_flops, height, width = conv_compute_flops(
                self.conv1, in_height, in_width, in_channels, mid_channels)
        else:
            # t == 1: conv1 is not counted; conv2 consumes the input directly.
            mid_channels = in_channels
            conv1_flops, height, width = 0, in_height, in_width
        out_channels = self.alpha2.expected_channel()

        # conv2 keeps the channel count (mid -> mid); conv3 maps mid -> out.
        conv2_flops, height, width = conv_compute_flops(
            self.conv2, height, width, mid_channels, mid_channels)
        conv3_flops, height, width = conv_compute_flops(
            self.conv3, height, width, mid_channels, out_channels)

        return conv1_flops + conv2_flops + conv3_flops, height, width
Ejemplo n.º 2
0
    def expected_flops(self, in_height, in_width):
        """Return the model's expected FLOPs, divided by 1e6 (i.e. in millions).

        The sum covers conv1, every DMCPBasicBlock / DMCPBottleneck found via
        ``self.modules()``, and the fc layer (counted as a 1x1 op). Each
        block's output alpha is fed as the input alpha of the next block.
        This assumes ``self.modules()`` yields blocks in forward order.

        Args:
            in_height: input image height.
            in_width: input image width.
        """
        flops_sum, height, width = conv_compute_flops(
            self.conv1,
            in_height,
            in_width,
            e_out_ch=self.alpha1.expected_channel())
        prev_alpha = self.alpha1

        # Spatial size is halved for the max-pool stage (floor division).
        # NOTE(review): for odd sizes this may differ from the pooling layer's
        # real output size depending on its kernel/padding — confirm.
        height = height // 2
        width = width // 2

        block_types = (DMCPBasicBlock, DMCPBottleneck)
        for module in self.modules():
            if not isinstance(module, block_types):
                continue
            block_flops, height, width = module.expected_flops(
                prev_alpha, height, width)
            flops_sum += block_flops
            # Bottlenecks expose their output alpha as alpha3, basic blocks as alpha2.
            prev_alpha = (module.alpha3 if isinstance(module, DMCPBottleneck)
                          else module.alpha2)

        fc_flops, _, _ = conv_compute_flops(
            self.fc, 1, 1, e_in_ch=prev_alpha.expected_channel())
        flops_sum += fc_flops

        return flops_sum / 1e6
Ejemplo n.º 3
0
    def expected_flops(self, in_height, in_width):
        """Return the model's expected FLOPs, divided by 1e6 (i.e. in millions).

        The sum covers conv1, every DMCPInvertedResidual found via
        ``self.modules()``, conv_last (from the last block's alpha to
        ``alpha_last``), and the fc layer (counted as a 1x1 op on
        ``alpha_last`` channels). This assumes ``self.modules()`` yields the
        blocks in forward order, since each block's ``alpha2`` becomes the
        next block's input alpha.

        Args:
            in_height: input image height.
            in_width: input image width.
        """
        flops_sum, height, width = conv_compute_flops(
            self.conv1,
            in_height,
            in_width,
            e_out_ch=self.alpha1.expected_channel())
        prev_alpha = self.alpha1

        for module in self.modules():
            if isinstance(module, DMCPInvertedResidual):
                block_flops, height, width = module.expected_flops(
                    prev_alpha, height, width)
                flops_sum += block_flops
                prev_alpha = module.alpha2

        # conv_last: last block's channels -> alpha_last channels.
        last_in = prev_alpha.expected_channel()
        last_out = self.alpha_last.expected_channel()
        conv_last_flops, _, _ = conv_compute_flops(
            self.conv_last, height, width, last_in, last_out)
        flops_sum += conv_last_flops

        # fc: evaluated on a 1x1 spatial map.
        fc_flops, _, _ = conv_compute_flops(
            self.fc, 1, 1, e_in_ch=self.alpha_last.expected_channel())
        flops_sum += fc_flops

        return flops_sum / 1e6
Ejemplo n.º 4
0
    def expected_flops(self, in_alpha, in_height, in_width):
        """Return the expected FLOPs of this bottleneck and its output size.

        Side effect: when ``self.downsample`` is None and ``self.alpha3`` is
        unset, ``self.alpha3`` is bound to ``in_alpha``. The block's output
        channels then follow its input channels.

        Args:
            in_alpha: alpha object of the preceding layer; its
                ``expected_channel()`` is used as the expected input channel
                count.
            in_height: input feature-map height.
            in_width: input feature-map width.

        Returns:
            ``(e_flops, out_height, out_width)``: the summed expected FLOPs of
            conv1, conv2, conv3 and (if present) ``downsample[0]``, and the
            spatial size reported by the last conv counted.
        """
        has_downsample = self.downsample is not None
        if not has_downsample and self.alpha3 is None:
            self.alpha3 = in_alpha
        in_channels = in_alpha.expected_channel()

        # Chain conv1 -> conv2 -> conv3; each conv's output alpha sets the
        # next conv's input channel count.
        stages = (
            (self.conv1, self.alpha1),
            (self.conv2, self.alpha2),
            (self.conv3, self.alpha3),
        )
        channels = in_channels
        height, width = in_height, in_width
        stage_flops = []
        for conv, alpha in stages:
            next_channels = alpha.expected_channel()
            flops, height, width = conv_compute_flops(
                conv, height, width, channels, next_channels)
            stage_flops.append(flops)
            channels = next_channels
        e_flops = sum(stage_flops)

        if has_downsample:
            # The shortcut conv maps the block input straight to conv3's channels.
            shortcut_flops, height, width = conv_compute_flops(
                self.downsample[0], in_height, in_width, in_channels, channels)
            e_flops += shortcut_flops

        return e_flops, height, width