Example #1
0
    def forward(self, input, weights=None):
        """Run the S-network over `input`.

        Downsamples through `n_downsample_S` stages, then — unless this module
        is configured as a pure encoder — upsamples back. The first
        `params_free_layers` stages (and optionally the very first conv) use
        externally supplied `weights` via batch_conv instead of this module's
        own parameters.

        Returns None for None input; a shallow-to-deep feature list when
        `self.netS == 'encoder'`; otherwise the reversed list of decoder
        features.
        """
        if input is None:
            return None

        # First convolution: externally parameterized or built-in.
        if self.first_layer_free:
            feats = [batch_conv(input, weights[0])]
            weights = weights[1:]
        else:
            feats = [self.conv_first(input)]

        # Downsampling path.
        for i in range(self.n_downsample_S):
            own_params = i >= self.params_free_layers or 'decoder' in self.netS
            if own_params:
                down = getattr(self, 'down_%d' % i)(feats[-1])
            else:
                down = batch_conv(feats[-1], weights[i], stride=2)
            feats.append(down)

        if self.netS == 'encoder':
            return feats

        # Upsampling path, seeded with the deepest feature only.
        feats = [feats[-1]]
        for i in reversed(range(self.n_downsample_S)):
            if i >= self.params_free_layers:
                up = getattr(self, 'up_%d' % i)(feats[-1])
            else:
                up = batch_conv(feats[-1], weights[i], stride=0.5)
            feats.append(up)
        return feats[::-1]
Example #2
0
    def forward(self, input, weights=None):
        """Encode `input` through the downsampling stack and, when
        `self.decode` is set, decode back up, optionally with U-Net style
        skip connections.

        Args:
            input: input tensor, or None (returns None immediately).
            weights: optional per-layer external weights consumed by
                batch_conv for the first `self.params_free_layers` stages
                (plus the first conv when `self.first_layer_free`).

        Returns:
            None if input is None; the shallow-to-deep encoder feature list
            when `self.decode` is falsy; otherwise the decoder feature list
            reversed (deepest last).
        """
        if input is None: return None
        # First convolution: externally parameterized or built-in.
        if self.first_layer_free:
            output = [batch_conv(input, weights[0])]
            weights = weights[1:]
        else:
            output = [self.conv_first(input)]
        # Downsampling path. External weights are only used in decode-off
        # mode for the first `params_free_layers` stages.
        for i in range(self.n_downsample_S):
            if i >= self.params_free_layers or self.decode:
                conv = getattr(self, 'down_%d' % i)(output[-1])
            else:
                conv = batch_conv(output[-1], weights[i], stride=2)
            output.append(conv)

        if not self.decode:
            return output

        # Non-unet: restart from the deepest feature only. Unet keeps the
        # full encoder list so skips can index into it below.
        if not self.unet:
            output = [output[-1]]
        for i in reversed(range(self.n_downsample_S)):
            input_i = output[-1]
            if self.unet and i != self.n_downsample_S - 1:
                # Skip connection: output[i + 1] (with i + 1 <= n_downsample_S)
                # is still an encoder feature, since decoder outputs are only
                # appended past index n_downsample_S.
                input_i = torch.cat([input_i, output[i + 1]], dim=1)
            if i >= self.params_free_layers:
                conv = getattr(self, 'up_%d' % i)(input_i)
            else:
                conv = batch_conv(input_i, weights[i], stride=0.5)
            output.append(conv)
        if self.unet:
            # Drop the shallow encoder entries so the result matches the
            # non-unet path: deepest encoder feature + decoder outputs.
            output = output[self.n_downsample_S:]
        return output[::-1]
Example #3
0
 def forward(self, x, maps, weights=None, k=None):
     """SPADE-style spatially-adaptive modulation.

     Normalizes `x`, then for each conditioning map computes per-pixel
     gamma/beta (from this module's MLPs, or from external `weights` for
     the first map only) and applies `out = out * (1 + gamma) + beta`.

     Args:
         x: input tensor to normalize and modulate.
         maps: a conditioning map or list of maps; None entries are skipped.
         weights: optional pair [gamma_weights, beta_weights] of external
             batch_conv weights, applied only to the first map (i == 0).
         k: unused; kept for interface compatibility with callers.

     Returns:
         The modulated tensor.
     """
     if not isinstance(maps, list):
         maps = [maps]
     out = self.batch_norm(x)
     size = x.size()[2:]  # hoisted: target spatial size is loop-invariant
     for i, cond in enumerate(maps):
         if cond is None:
             continue
         # align_corners=False matches F.interpolate's current default for
         # bilinear mode and silences the deprecation warning.
         m = F.interpolate(cond, size=size, mode='bilinear', align_corners=False)
         if weights is None or i != 0:
             # Module's own MLPs; the first map uses the unsuffixed names.
             s = str(i + 1) if i > 0 else ''
             gamma = getattr(self, 'mlp_gamma%s' % s)(m)
             beta = getattr(self, 'mlp_beta%s' % s)(m)
         else:
             # External weights path (only reachable with i == 0).
             j = min(i, len(weights[0]) - 1)
             gamma = batch_conv(m, weights[0][j])
             beta = batch_conv(m, weights[1][j])
         out = out * (1 + gamma) + beta
     return out
Example #4
0
 def forward(self, input, weight=None, bias=None, stride=1):
     """Thin wrapper that delegates directly to batch_conv with the
     given external weight/bias and stride."""
     result = batch_conv(input, weight, bias, stride)
     return result