コード例 #1
0
def fcn_G(input_dim, nn, imgsz, channels, requires_grad, depth=2):
    """Create a fully-connected generator as a functional ``(f, params)`` pair.

    The network is ``depth`` linear+ReLU layers (``input_dim -> nn -> ... -> nn``).
    They are followed by a linear projection to ``imgsz * imgsz * channels``
    values, a tanh, and a reshape into an image batch.

    Args:
        input_dim: size of each input vector.
        nn: width of every hidden layer.
        imgsz: height and width of the produced images.
        channels: number of image channels.
        requires_grad: when true, ``utils.set_requires_grad_except_bn_`` is
            applied to the created parameters.
        depth: number of hidden linear+ReLU layers.

    Returns:
        ``(f, flat_params)``.
        - ``f(input, params, mode)`` returns a tensor reshaped to
          ``(batch, channels, imgsz, imgsz)``; ``mode`` is accepted for
          interface compatibility but unused.
        - ``flat_params`` is the flattened, cast parameter dict.
    """
    # Hidden layers are created in index order, before ``last_proj``, so the
    # parameter-creation order is fixed.
    hidden = {}
    for idx in range(depth):
        fan_in = input_dim if idx == 0 else nn
        hidden['block%d' % idx] = {'fc': utils.linear_params(fan_in, nn)}
    nested = {
        'group0': hidden,
        'last_proj': utils.linear_params(nn, imgsz * imgsz * channels),
    }
    flat_params = utils.cast(utils.flatten(nested))

    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    # Flattened key prefixes of the hidden layers, e.g. 'group0.block0'.
    layer_keys = ['group0.block%d' % idx for idx in range(depth)]

    def f(input, params, mode):
        out = input
        for key in layer_keys:
            hidden_out = F.linear(out, params[key + '.fc.weight'], params[key + '.fc.bias'])
            out = F.relu(hidden_out, inplace=True)
        out = torch.tanh(F.linear(out, params['last_proj.weight'], params['last_proj.bias']))
        return out.reshape(out.size(0), channels, imgsz, imgsz)

    return f, flat_params
コード例 #2
0
def dcganx_D(nn0, imgsz,
             channels,    # 1: gray-scale, 3: color
             norm_type,   # 'bn', 'none'
             requires_grad,
             depth=3, leaky_slope=0.2, nodemul=2, do_bias=True):
    """Build a DCGAN-style discriminator as a functional ``(f, params)`` pair.

    Layout of ``f``:
      * ``conv0``: 5x5 conv ``channels -> nn0`` (stride 2, padding 2), then
        LeakyReLU.
      * ``depth - 1`` groups of ``count`` blocks each. Every block has these
        stages:
          - a 5x5 conv (stride 2 on the group's first block, padding 2);
          - a 1x1 conv;
          - after each conv, batch norm when ``norm_type == 'bn'``, then
            LeakyReLU.
        Each group multiplies the channel count by ``nodemul``.
      * ``fc``: linear map of the flattened features to a single output.

    The input is halved spatially ``depth`` times, so ``fc`` is sized for
    ``(imgsz // 2**depth)**2`` positions. This matches the real feature map
    only when ``imgsz`` is divisible by ``2**depth``.

    Args:
        nn0: channel count after ``conv0``.
        imgsz: input height/width.
        channels: input channels (1: gray-scale, 3: color).
        norm_type: ``'bn'`` enables batch norm in the groups; otherwise none.
        requires_grad: when true, ``utils.set_requires_grad_except_bn_`` is
            applied to the created parameters.
        depth: number of stride-2 stages (``conv0`` plus ``depth - 1`` groups).
        leaky_slope: negative slope passed to ``F.leaky_relu``.
        nodemul: channel multiplier per group.
        do_bias: forwarded to ``conv2d_params`` to request conv biases.

    Returns:
        ``(f, flat_params)``.
        - ``f(input, params, mode)`` returns a ``(batch, 1)`` tensor.
        - ``mode`` is forwarded to ``utils.batch_norm``.
    """
    ker = 5
    padding = 2
    count = 1  # blocks per group

    def gen_block_params(ni, no, k):
        return {
            'conv0': conv2d_params(ni, no, k, do_bias),
            'conv1': conv2d_params(no, no, 1, do_bias),
            'bn0': utils.bnparams(no) if norm_type == 'bn' else None,
            'bn1': utils.bnparams(no) if norm_type == 'bn' else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no, ker)
                for i in range(count)}

    sz = imgsz // (2**depth)
    nn = nn0
    p = {'conv0': conv2d_params(channels, nn0, ker, do_bias)}
    for d in range(depth - 1):
        p['group%d' % d] = gen_group_params(nn, nn * nodemul, count)
        nn = nn * nodemul
    p['fc'] = utils.linear_params(sz * sz * nn, 1)
    flat_params = utils.cast(utils.flatten(p))

    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride):
        # Bias keys need the '.' separator ('<base>.conv0.b') to match the
        # flattened parameter names; without it the biases are never applied.
        o = F.conv2d(x, params[base + '.conv0.w'], params.get(base + '.conv0.b'),
                     stride=stride, padding=padding)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn0', mode)
        o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
        o = F.conv2d(o, params[base + '.conv1.w'], params.get(base + '.conv1.b'),
                     stride=1, padding=0)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn1', mode)
        o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
        return o

    def group(o, params, base, mode, stride=2):
        # Iterate over the same number of blocks the params were built with.
        for i in range(count):
            o = block(o, params, '%s.block%d' % (base, i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        o = F.conv2d(input, params['conv0.w'], params.get('conv0.b'), stride=2, padding=padding)
        o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
        for d in range(depth - 1):
            o = group(o, params, 'group%d' % d, mode)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
コード例 #3
0
def dcganx_G(input_dim, n0g, imgsz, channels,
             norm_type,  # 'bn', 'none'
             requires_grad, depth=3,
             nodemul=2, do_bias=True):
    """Create a DCGAN-style generator as a functional ``(f, params)`` pair.

    ``f(input, params, mode)`` performs three stages.
      1. ``proj``: a linear map ``input_dim -> nn0 * sz * sz``, viewed as
         ``(batch, nn0, sz, sz)``. Here ``nn0 = n0g * nodemul**(depth-1)``
         and ``sz = imgsz // 2**depth``.
      2. ``depth - 1`` up-sampling groups of ``count`` blocks each. Every block
         runs, in order:
           - ReLU;
           - 5x5 transposed conv (stride 2 on a group's first block,
             padding 2, output_padding 1);
           - optional BN;
           - ReLU;
           - 1x1 conv;
           - optional BN.
         Each group divides the channel count by ``nodemul`` (integer
         division).
      3. ReLU, then ``last_convT`` (5x5 transposed conv, stride 2) down to
         ``channels`` maps, then tanh.

    With a 5x5 kernel, stride 2, padding 2 and output_padding 1, every
    transposed conv exactly doubles the spatial size. The output is therefore
    ``sz * 2**depth`` on each side, which equals ``imgsz`` when ``imgsz`` is
    divisible by ``2**depth``.

    ``mode`` is forwarded to ``utils.batch_norm``. Batch norm is used only
    when ``norm_type == 'bn'``.

    Returns:
        ``(f, flat_params)``.
    """
    ker, padding, output_padding = 5, 2, 1
    count = 1                                # blocks per group
    use_bn = norm_type == 'bn'
    nn0 = n0g * (nodemul**(depth - 1))       # channels right after ``proj``
    sz = imgsz // (2**depth)                 # spatial size right after ``proj``

    def make_block(ni, no):
        return {
            'convT0': conv2dT_params(ni, no, ker, do_bias),
            'conv1': conv2d_params(no, no, 1, do_bias),
            'bn0': utils.bnparams(no) if use_bn else None,
            'bn1': utils.bnparams(no) if use_bn else None,
        }

    nested = {'proj': utils.linear_params(input_dim, nn0 * sz * sz)}
    width = nn0
    for d in range(depth - 1):
        narrower = width // nodemul
        nested['group%d' % d] = {
            'block%d' % i: make_block(width if i == 0 else narrower, narrower)
            for i in range(count)
        }
        width = narrower
    nested['last_convT'] = conv2dT_params(width, channels, ker, do_bias)
    flat_params = utils.cast(utils.flatten(nested))

    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def up_block(x, params, prefix, mode, stride):
        h = F.conv_transpose2d(F.relu(x, inplace=True),
                               params[prefix + '.convT0.w'], params.get(prefix + '.convT0.b'),
                               stride=stride, padding=padding, output_padding=output_padding)
        if use_bn:
            h = utils.batch_norm(h, params, prefix + '.bn0', mode)
        h = F.conv2d(F.relu(h, inplace=True),
                     params[prefix + '.conv1.w'], params.get(prefix + '.conv1.b'),
                     stride=1, padding=0)
        if use_bn:
            h = utils.batch_norm(h, params, prefix + '.bn1', mode)
        return h

    def f(input, params, mode):
        h = F.linear(input, params['proj.weight'], params['proj.bias'])
        h = h.view(input.size(0), nn0, sz, sz)
        for d in range(depth - 1):
            base = 'group%d' % d
            for i in range(count):
                h = up_block(h, params, '%s.block%d' % (base, i), mode, 2 if i == 0 else 1)
        h = F.conv_transpose2d(F.relu(h, inplace=True),
                               params['last_convT.w'], params.get('last_convT.b'),
                               stride=2, padding=padding, output_padding=output_padding)
        return torch.tanh(h)

    return f, flat_params
コード例 #4
0
def resnet4_G(input_dim, n0g, imgsz, channels,
             norm_type,  # 'bn', 'none'
             requires_grad,
             do_bias=True):
   """Build a 4-stage residual generator as a functional (f, params) pair.

   f(input, params, mode) performs these steps:
     * 'proj': a linear map input_dim -> nn*sz*sz, viewed as (batch, nn, sz, sz).
       Here nn = n0g * 8 and sz = imgsz // 16.
     * Four residual groups ('group0'..'group3'). Each group upsamples 2x
       (nearest) in its first block.
     * Channels go nn -> nn//2 -> nn//4 -> nn//8 -> nn//8.
     * ReLU, then a 3x3 'last_conv' to `channels` maps, then tanh.
   The output side length is 16 * sz, i.e. imgsz when imgsz is divisible by 16.

   Args:
      input_dim: size of each input vector.
      n0g: base channel count (the last two groups have n0g channels).
      imgsz: target image height/width.
      channels: number of output image channels.
      norm_type: 'bn' adds batch norm after each block's second conv.
      requires_grad: when true, utils.set_requires_grad_except_bn_ is applied
         to the created parameters.
      do_bias: forwarded to conv2d_params to request conv biases.

   Returns:
      (f, flat_params). `mode` given to f is forwarded to utils.batch_norm.
   """
   depth = 4                # number of 2x upsampling groups
   ker = 3
   padding = (ker-1)//2     # keeps spatial size for stride-1 3x3 convs
   count = 1                # residual blocks per group

   # Parameters of one residual block: two kxk convs (ni->no, no->no), a 1x1
   # shortcut weight 'convdim' (ni->no), and optional batch-norm params.
   def gen_resnet_G_block_params(ni, no, k, norm_type, do_bias):
      return {
         'conv0': conv2d_params(ni, no, k, do_bias), 
         'conv1': conv2d_params(no, no, k, do_bias), 
         'convdim': utils.conv_params(ni, no, 1), 
         'bn': utils.bnparams(no) if norm_type == 'bn' else None
      }

   def gen_group_params(ni, no):
       return {'block%d' % i: gen_resnet_G_block_params(ni if i == 0 else no, no, ker, norm_type, do_bias) for i in range(count)}

   nn = n0g * (2**(depth-1)); sz = imgsz // (2**depth)
   flat_params = utils.cast(utils.flatten({
        'proj': utils.linear_params(input_dim, nn*sz*sz),
        'group0': gen_group_params(nn,    nn//2),
        'group1': gen_group_params(nn//2, nn//4),
        'group2': gen_group_params(nn//4, nn//8),
        'group3': gen_group_params(nn//8, nn//8),
        'last_conv': conv2d_params(nn//8, channels, ker, do_bias),
   }))

   if requires_grad:
      utils.set_requires_grad_except_bn_(flat_params)

   # Pre-activation residual block with optional 2x nearest upsampling.
   def block(x, params, base, mode, do_upsample):
      # inplace=True overwrites `x` itself, so the shortcut below
      # (F.conv2d(x, convdim)) sees relu(x), not the raw block input.
      # NOTE(review): confirm the shortcut was meant to use relu(x); changing it
      # alters the architecture of trained models.
      o = F.relu(x, inplace=True)
      if do_upsample:
        o = F.interpolate(o, scale_factor=2, mode='nearest')

      o = F.conv2d(o, params[base+'.conv0.w'], params.get(base+'.conv0.b'), padding=padding)
      o = F.relu(o, inplace=True)
      o = F.conv2d(o, params[base+'.conv1.w'], params.get(base+'.conv1.b'), padding=padding)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn', mode)

      # Shortcut: bias-free 1x1 conv to match channel count. It is applied
      # before upsampling, which is equivalent for a 1x1 conv with nearest
      # interpolation.
      xo = F.conv2d(x, params[base + '.convdim']) 
      if do_upsample:
         return o + F.interpolate(xo, scale_factor=2, mode='nearest')
      else:
         return o + xo

   # Runs `count` blocks; only the first block of a group upsamples.
   def group(o, params, base, mode, do_upsample):
      for i in range(count):
         o = block(o, params, '%s.block%d' % (base,i), mode, do_upsample if i == 0 else False)
      return o

   # Debugging helper; not called by f.
   def show_shape(o, msg=''):
      print(o.size(), msg)

   def f(input, params, mode):
      o = F.linear(input, params['proj.weight'], params['proj.bias'])
      o = o.view(input.size(0), nn, sz, sz)
      o = group(o, params, 'group0', mode, do_upsample=True)
      o = group(o, params, 'group1', mode, do_upsample=True)
      o = group(o, params, 'group2', mode, do_upsample=True)
      o = group(o, params, 'group3', mode, do_upsample=True)
      o = F.relu(o, inplace=True)
      o = F.conv2d(o, params['last_conv.w'], params.get('last_conv.b'), padding=padding)
      o = torch.tanh(o)
      return o

   return f, flat_params
   
コード例 #5
0
def resnet4_D(nn, imgsz,
              channels,    # 1: gray-scale, 3: color
              norm_type,  # 'bn', 'none'
              requires_grad,
              do_bias=True):
    """Build a 4-stage residual discriminator as a functional ``(f, params)`` pair.

    Layout of ``f``:
      * ``group0``: a residual block ``channels -> nn`` with no
        pre-activation on the raw input.
      * ``group1``..``group3``: residual blocks ``nn -> 2nn -> 4nn -> 8nn``.
        In each block, ``conv0`` keeps the input width and ``conv1`` changes it.
      * Every group ends with a 2x2 average pool on both the residual and
        shortcut paths.
      * Final ReLU, flatten, and ``fc`` to a single output.

    ``fc`` is sized for ``(imgsz // 16)**2`` positions. This matches the real
    feature map only when ``imgsz`` is divisible by 16.

    Args:
        nn: channel count after ``group0``.
        imgsz: input height/width.
        channels: input channels (1: gray-scale, 3: color).
        norm_type: ``'bn'`` adds batch norm after each block's second conv.
        requires_grad: when true, ``utils.set_requires_grad_except_bn_`` is
            applied to the created parameters.
        do_bias: forwarded to ``conv2d_params`` to request conv biases.

    Returns:
        ``(f, flat_params)``.
        - ``f(input, params, mode)`` returns a ``(batch, 1)`` tensor.
        - ``mode`` is forwarded to ``utils.batch_norm``.
    """
    depth = 4                  # number of 2x downsampling groups
    ker = 3
    padding = (ker - 1) // 2   # keeps spatial size for stride-1 3x3 convs
    count = 1                  # residual blocks per group

    def gen_group0_params(no):
        ni = channels
        return {'block0': {
            'conv0': conv2d_params(ni, no, ker, do_bias),
            'conv1': conv2d_params(no, no, ker, do_bias),
            'convdim': utils.conv_params(ni, no, 1),
            'bn': utils.bnparams(no) if norm_type == 'bn' else None,
        }}

    def gen_resnet_D_block_params(ni, no, k, norm_type, do_bias):
        return {
            'conv0': conv2d_params(ni, ni, k, do_bias),
            'conv1': conv2d_params(ni, no, k, do_bias),
            'convdim': utils.conv_params(ni, no, 1),
            'bn': utils.bnparams(no) if norm_type == 'bn' else None,
        }

    def gen_group_params(ni, no):
        return {'block%d' % i: gen_resnet_D_block_params(ni if i == 0 else no, no, ker, norm_type, do_bias)
                for i in range(count)}

    sz = imgsz // (2**depth)
    flat_params = utils.cast(utils.flatten({
        'group0': gen_group0_params(nn),
        'group1': gen_group_params(nn, nn * 2),
        'group2': gen_group_params(nn * 2, nn * 4),
        'group3': gen_group_params(nn * 4, nn * 8),
        'fc': utils.linear_params(sz * sz * nn * 8, 1),
    }))

    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, do_downsample, is_first):
        o = x
        if not is_first:
            # In-place: this also overwrites ``x``, so the shortcut below uses
            # relu(x). NOTE(review): kept to preserve the trained architecture;
            # confirm whether the shortcut was meant to see the raw input.
            o = F.relu(o, inplace=True)
        # Bias keys need the '.' separator ('<base>.conv0.b') to match the
        # flattened parameter names; without it the biases are never applied.
        o = F.conv2d(o, params[base + '.conv0.w'], params.get(base + '.conv0.b'), padding=padding)
        o = F.relu(o, inplace=True)
        o = F.conv2d(o, params[base + '.conv1.w'], params.get(base + '.conv1.b'), padding=padding)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn', mode)

        if do_downsample:
            o = F.avg_pool2d(o, 2)
            x = F.avg_pool2d(x, 2)

        # Shortcut: bias-free 1x1 conv to match channel count.
        if base + '.convdim' in params:
            return o + F.conv2d(x, params[base + '.convdim'])
        return o + x

    def group(o, params, base, mode, do_downsample, is_first=False):
        # Only the last block of a group downsamples; only the very first
        # block of the network skips the pre-activation ReLU.
        for i in range(count):
            o = block(o, params, '%s.block%d' % (base, i), mode,
                      do_downsample=(do_downsample and i == count - 1),
                      is_first=(is_first and i == 0))
        return o

    def f(input, params, mode):
        o = group(input, params, 'group0', mode, do_downsample=True, is_first=True)
        o = group(o, params, 'group1', mode, do_downsample=True)
        o = group(o, params, 'group2', mode, do_downsample=True)
        o = group(o, params, 'group3', mode, do_downsample=True)
        o = F.relu(o, inplace=True)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
コード例 #6
0
 def gen_block_params(ni, no):
    """Return the parameters of one fully-connected block as ``{'fc': ...}``."""
    return dict(fc=utils.linear_params(ni, no))