def register(self, modules, Q_g, Q_a, W_star, use_patch, fix_rotation, re_init):
    """Swap selected layers of this network for bottleneck-compressed versions.

    Every ``Bottleneck`` submodule is rewritten via ``_update_bottleneck``;
    then the stem ``conv1`` and the head ``linear`` are replaced with
    ``register_bottleneck_layer(...)`` wrappers when they appear in *modules*.

    Args:
        modules: collection of layer objects eligible for replacement.
        Q_g, Q_a: per-layer rotation dicts, keyed by the original layer;
            ``update_QQ_dict`` re-keys entries onto the replacement layer.
        W_star: per-layer compressed weights, keyed like Q_g/Q_a.
        use_patch, fix_rotation: forwarded to ``register_bottleneck_layer``.
        re_init: when truthy, re-initialize all weights with ``_weights_init``.
    """
    for m in self.modules():
        if isinstance(m, Bottleneck):
            # NOTE: a leftover debug `print(m)` was removed here.
            self._update_bottleneck(m, modules, Q_g, Q_a, W_star,
                                    use_patch, fix_rotation)

    m = self.conv1
    # A previously-registered layer is an nn.Sequential whose original
    # layer sits at index 1; unwrap it so the `modules` lookup matches.
    if isinstance(m, nn.Sequential):
        m = m[1]
    if m in modules:
        self.conv1 = register_bottleneck_layer(
            m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
        update_QQ_dict(Q_g, Q_a, m, self.conv1[1])

    m = self.linear
    if isinstance(m, nn.Sequential):
        m = m[1]
    if m in modules:
        self.linear = register_bottleneck_layer(
            m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
        update_QQ_dict(Q_g, Q_a, m, self.linear[1])

    self._is_registered = True
    if re_init:
        self.apply(_weights_init)
def _update_bottleneck(self, bneck, modules, Q_g, Q_a, W_star, use_patch, fix_rotation):
    """Replace a single Bottleneck's conv1/conv2/conv3 and downsample in place.

    Each conv that appears in *modules* is swapped for a
    ``register_bottleneck_layer(...)`` wrapper and the Q_g/Q_a dicts are
    re-keyed onto the new layer via ``update_QQ_dict``.

    Args:
        bneck: the Bottleneck block to rewrite (mutated in place).
        modules: collection of layer objects eligible for replacement.
        Q_g, Q_a, W_star: per-layer dicts keyed by the original layer.
        use_patch, fix_rotation: forwarded to ``register_bottleneck_layer``.
    """
    # The three convs follow an identical pattern; loop instead of
    # repeating the block three times.
    for attr in ('conv1', 'conv2', 'conv3'):
        m = getattr(bneck, attr)
        # An already-registered layer is an nn.Sequential with the
        # original layer at index 1; unwrap before the membership test.
        if isinstance(m, nn.Sequential):
            m = m[1]
        if m in modules:
            new_layer = register_bottleneck_layer(
                m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
            setattr(bneck, attr, new_layer)
            update_QQ_dict(Q_g, Q_a, m, new_layer[1])

    m = bneck.downsample
    if m is not None:
        # Fresh downsample: Sequential of length 1 holding the conv.
        if len(m) == 1 and m[0] in modules:
            m = m[0]
            bneck.downsample = register_bottleneck_layer(
                m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
            update_QQ_dict(Q_g, Q_a, m, bneck.downsample[1])
        # Already-registered downsample: Sequential of length 3 with the
        # original conv at index 1.
        elif len(m) == 3 and m[1] in modules:
            m = m[1]
            bneck.downsample = register_bottleneck_layer(
                m, Q_g[m], Q_a[m], W_star[m], use_patch, fix_rotation)
            update_QQ_dict(Q_g, Q_a, m, bneck.downsample[1])
        else:
            # Typo fixed: 'Upexpected' -> 'Unexpected'.
            assert len(m) == 1 or len(m) == 3, 'Unexpected layer %s' % m
def register(self, modules, Q_g, Q_a, W_star, use_patch, fix_rotation, re_init):
    """Install bottleneck-compressed replacements for feature and classifier layers.

    Walks ``self.feature`` by index, swapping each layer found in *modules*
    for a ``register_bottleneck_layer(...)`` wrapper, then does the same for
    ``self.classifier``; Q_g/Q_a entries are re-keyed via ``update_QQ_dict``.

    Args:
        modules: collection of layer objects eligible for replacement.
        Q_g, Q_a, W_star: per-layer dicts keyed by the original layer.
        use_patch, fix_rotation: forwarded to ``register_bottleneck_layer``.
        re_init: when truthy, re-initialize all weights with ``_weights_init``.
    """
    for idx in range(len(self.feature)):
        layer = self.feature[idx]
        # Already-registered layers are Sequentials with the original
        # layer at index 1; unwrap before the membership test.
        if isinstance(layer, nn.Sequential):
            layer = layer[1]
        if layer not in modules:
            continue
        self.feature[idx] = register_bottleneck_layer(
            layer, Q_g[layer], Q_a[layer], W_star[layer],
            use_patch, fix_rotation)
        update_QQ_dict(Q_g, Q_a, layer, self.feature[idx][1])

    layer = self.classifier
    if isinstance(layer, nn.Sequential):
        layer = layer[1]
    if layer in modules:
        self.classifier = register_bottleneck_layer(
            layer, Q_g[layer], Q_a[layer], W_star[layer],
            use_patch, fix_rotation)
        # NOTE(review): sibling methods pass the unwrapped layer
        # (e.g. `self.linear[1]`) here, but this one passes the whole
        # Sequential — confirm whether `self.classifier[1]` was intended.
        update_QQ_dict(Q_g, Q_a, layer, self.classifier)

    self._is_registered = True
    if re_init:
        self.apply(_weights_init)
def register(self, modules, Q_g, Q_a, W_star, use_patch, fix_rotation, re_init):
    """Install bottleneck-compressed replacements across this network.

    Rewrites every ``Bottleneck`` submodule through ``_update_bottleneck``,
    then replaces the stem ``conv1`` and head ``fc`` with
    ``register_bottleneck_layer(...)`` wrappers when they appear in *modules*.

    Args:
        modules: collection of layer objects eligible for replacement.
        Q_g, Q_a, W_star: per-layer dicts keyed by the original layer;
            ``update_QQ_dict`` re-keys entries onto the replacement.
        use_patch, fix_rotation: forwarded to ``register_bottleneck_layer``.
        re_init: re-initialization is not supported for this network.

    Raises:
        NotImplementedError: if *re_init* is truthy.
    """
    for sub in self.modules():
        if isinstance(sub, Bottleneck):
            self._update_bottleneck(sub, modules, Q_g, Q_a, W_star,
                                    use_patch, fix_rotation)

    # Handle the two top-level layers with one shared code path.
    for attr in ('conv1', 'fc'):
        layer = getattr(self, attr)
        # A previously-registered layer is a Sequential holding the
        # original layer at index 1; unwrap it for the lookup.
        if isinstance(layer, nn.Sequential):
            layer = layer[1]
        if layer not in modules:
            continue
        replacement = register_bottleneck_layer(
            layer, Q_g[layer], Q_a[layer], W_star[layer],
            use_patch, fix_rotation)
        setattr(self, attr, replacement)
        update_QQ_dict(Q_g, Q_a, layer, replacement[1])

    self._is_registered = True
    if re_init:
        raise NotImplementedError