示例#1
0
    def register(self, modules, Q_g, Q_a, W_star, use_patch, fix_rotation,
                 re_init):
        """Swap every selected layer of the network for its registered form.

        Each ``Bottleneck`` found by ``self.modules()`` is printed and handed
        to ``self._update_bottleneck``. Afterwards ``self.conv1`` and
        ``self.linear`` are processed directly.

        Args:
            modules: Container of layers to register; membership is tested
                with ``in``.
            Q_g: Indexed by layer; the entry is forwarded to
                ``register_bottleneck_layer`` and the container is passed to
                ``update_QQ_dict``.
            Q_a: Used the same way as ``Q_g``.
            W_star: Indexed by layer; the entry is forwarded to
                ``register_bottleneck_layer``.
            use_patch: Forwarded unchanged to ``register_bottleneck_layer``.
            fix_rotation: Forwarded unchanged to
                ``register_bottleneck_layer``.
            re_init: When truthy, ``_weights_init`` is applied to the whole
                network via ``self.apply`` after registration.
        """
        for block in self.modules():
            if not isinstance(block, Bottleneck):
                continue
            print(block)
            self._update_bottleneck(block, modules, Q_g, Q_a, W_star,
                                    use_patch, fix_rotation)

        # The two top-level layers follow the same recipe, so drive them by
        # attribute name instead of repeating the stanza.
        for attr in ('conv1', 'linear'):
            layer = getattr(self, attr)
            if isinstance(layer, nn.Sequential):
                # For a Sequential, the layer of interest is the one at
                # index 1.
                layer = layer[1]
            if layer not in modules:
                continue
            setattr(self, attr,
                    register_bottleneck_layer(layer, Q_g[layer], Q_a[layer],
                                              W_star[layer], use_patch,
                                              fix_rotation))
            update_QQ_dict(Q_g, Q_a, layer, getattr(self, attr)[1])

        self._is_registered = True
        if re_init:
            self.apply(_weights_init)
示例#2
0
    def _update_bottleneck(self, bneck, modules, Q_g, Q_a, W_star, use_patch, fix_rotation):
        m = bneck.conv1
        if isinstance(m, nn.Sequential):
            m = m[1]
        if m in modules:
            bneck.conv1 = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
            update_QQ_dict(Q_g, Q_a, m, bneck.conv1[1])

        m = bneck.conv2
        if isinstance(m, nn.Sequential):
            m = m[1]
        if m in modules:
            bneck.conv2 = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
            update_QQ_dict(Q_g, Q_a, m, bneck.conv2[1])

        m = bneck.conv3
        if isinstance(m, nn.Sequential):
            m = m[1]
        if m in modules:
            bneck.conv3 = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
            update_QQ_dict(Q_g, Q_a, m, bneck.conv3[1])

        m = bneck.downsample
        if m is not None:
            if len(m) == 1 and m[0] in modules:
                m = m[0]
                bneck.downsample = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
                update_QQ_dict(Q_g, Q_a, m, bneck.downsample[1])
            elif len(m) == 3 and m[1] in modules:
                m = m[1]
                bneck.downsample = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
                update_QQ_dict(Q_g, Q_a, m, bneck.downsample[1])
            else:
                assert len(m) == 1 or len(m) == 3, 'Upexpected layer %s' % m
示例#3
0
 def register(self, modules, Q_g, Q_a, W_star, use_patch, fix_rotation, re_init):
     """Replace the selected feature layers and the classifier with registered ones.

     Every entry of ``self.feature`` and then ``self.classifier`` is examined.
     An entry that is an ``nn.Sequential`` is represented by its element at
     index 1. A layer found in ``modules`` is replaced by the result of
     ``register_bottleneck_layer``, and ``update_QQ_dict`` is then called
     with the old layer and element 1 of the replacement.

     Args:
         modules: Container of layers to register; membership is tested
             with ``in``.
         Q_g: Indexed by layer; the entry is forwarded to
             ``register_bottleneck_layer`` and the container is passed to
             ``update_QQ_dict``.
         Q_a: Used the same way as ``Q_g``.
         W_star: Indexed by layer; the entry is forwarded to
             ``register_bottleneck_layer``.
         use_patch: Forwarded unchanged to ``register_bottleneck_layer``.
         fix_rotation: Forwarded unchanged to ``register_bottleneck_layer``.
         re_init: When truthy, ``_weights_init`` is applied to the whole
             network via ``self.apply`` after registration.
     """
     # Index-based loop on purpose: entries of self.feature are reassigned
     # while walking it.
     for idx in range(len(self.feature)):
         m = self.feature[idx]
         if isinstance(m, nn.Sequential):
             m = m[1]
         if m in modules:
             self.feature[idx] = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
             update_QQ_dict(Q_g, Q_a, m, self.feature[idx][1])

     m = self.classifier
     if isinstance(m, nn.Sequential):
         m = m[1]
     if m in modules:
         self.classifier = register_bottleneck_layer(m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
         # Pass element 1 of the replacement, as the feature loop does.
         # The unwrap above looks the layer up by that element, so passing
         # the whole wrapper would not line up with later lookups.
         update_QQ_dict(Q_g, Q_a, m, self.classifier[1])

     self._is_registered = True
     if re_init:
         self.apply(_weights_init)
示例#4
0
    def register(self, modules, Q_g, Q_a, W_star, use_patch, fix_rotation, re_init):
        """Swap every selected layer of the network for its registered form.

        Each ``Bottleneck`` found by ``self.modules()`` is handed to
        ``self._update_bottleneck``. Afterwards ``self.conv1`` and ``self.fc``
        are processed directly.

        Args:
            modules: Container of layers to register; membership is tested
                with ``in``.
            Q_g: Indexed by layer; the entry is forwarded to
                ``register_bottleneck_layer`` and the container is passed to
                ``update_QQ_dict``.
            Q_a: Used the same way as ``Q_g``.
            W_star: Indexed by layer; the entry is forwarded to
                ``register_bottleneck_layer``.
            use_patch: Forwarded unchanged to ``register_bottleneck_layer``.
            fix_rotation: Forwarded unchanged to
                ``register_bottleneck_layer``.
            re_init: Re-initialisation is not supported here; a truthy value
                raises once registration has finished.

        Raises:
            NotImplementedError: If ``re_init`` is truthy. It is raised after
                all layers are processed and ``_is_registered`` is set.
        """
        for block in self.modules():
            if not isinstance(block, Bottleneck):
                continue
            self._update_bottleneck(block, modules, Q_g, Q_a, W_star, use_patch, fix_rotation)

        # The two top-level layers follow the same recipe, so drive them by
        # attribute name instead of repeating the stanza.
        for attr in ('conv1', 'fc'):
            candidate = getattr(self, attr)
            if isinstance(candidate, nn.Sequential):
                # For a Sequential, the layer of interest is the one at
                # index 1.
                candidate = candidate[1]
            if candidate not in modules:
                continue
            setattr(self, attr,
                    register_bottleneck_layer(candidate, Q_g[candidate], Q_a[candidate],
                                              W_star[candidate], use_patch, fix_rotation))
            update_QQ_dict(Q_g, Q_a, candidate, getattr(self, attr)[1])

        self._is_registered = True
        if re_init:
            raise NotImplementedError