Example #1: DDG.icfg
    def icfg(self, loader, iter, d_loss, cfg_U):
        timeLog('DDG::icfg ... ICFG with cfg_U=%d' % cfg_U)
        self.check_trainability()
        t_inc = 1 if self.verbose else 5
        is_train = True
        # Train a sequence of num_D() discriminator snapshots, cfg_U update
        # steps each; store_d_params(t) keeps each snapshot for generate().
        for t in range(self.num_D()):
            sum_real = sum_fake = count = 0
            for upd in range(cfg_U):
                sample, iter = get_next(loader, iter)

                num = sample[0].size(0)
                fake = self.generate(num, t=t)
                d_out_real = self.d_net(cast(sample[0]), self.d_params,
                                        is_train)
                d_out_fake = self.d_net(cast(fake), self.d_params, is_train)
                loss = d_loss(d_out_real, d_out_fake)
                loss.backward()
                self.d_optimizer.step()
                self.d_optimizer.zero_grad()

                with torch.no_grad():
                    sum_real += float(d_out_real.sum())
                    sum_fake += float(d_out_fake.sum())
                    count += num

            self.store_d_params(t)

            if t_inc > 0 and ((t + 1) % t_inc == 0 or t == self.num_D() - 1):
                logging('  t=%d: real=%s, fake=%s' %
                        (t + 1, sum_real / count, sum_fake / count))

        raise_if_nan(sum_real)
        raise_if_nan(sum_fake)

        return iter, (sum_real - sum_fake) / count
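
The get_next helper used above is not shown in these examples; here is a minimal sketch of the behavior the loop relies on (an assumption, not the repository's code): fetch one batch and restart the iterator when the epoch ends.

def get_next(loader, it):
    # Hypothetical helper: pull the next batch; when the DataLoader
    # iterator is exhausted, start a fresh epoch and return the new
    # iterator alongside the sample.
    if it is None:
        it = iter(loader)
    try:
        sample = next(it)
    except StopIteration:
        it = iter(loader)
        sample = next(it)
    return sample, it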
Example #2: DDG.initialize_G
    def initialize_G(self, g_loss, cfg_N):
        timeLog('DDG::initialize_G ... Initializing tilde(G) ... ')
        # Probe the generator once to learn the flattened image dimension.
        z = self.z_gen(1)
        g_out = self.g_net(cast(z), self.g_params, False)
        img_dim = g_out.view(g_out.size(0), -1).size(1)

        batch_size = self.optim_config.x_batch_size
        z_dim = self.z_gen(1).size(1)
        # Targets for tilde(G): a fixed random linear projection of z.
        params = {'proj.w': normal_(torch.empty(z_dim, img_dim), std=0.01)}
        params['proj.w'].requires_grad = True

        num_gened = 0
        fakes = torch.empty(cfg_N, img_dim)
        zs = torch.empty(cfg_N, z_dim)
        with torch.no_grad():
            while num_gened < cfg_N:
                num = min(batch_size, cfg_N - num_gened)
                z = self.z_gen(num)
                fake = torch.mm(z, params['proj.w'])
                fakes[num_gened:num_gened + num] = fake
                zs[num_gened:num_gened + num] = z
                num_gened += num

        to_pm1(fakes)  # -> [-1,1]

        sz = [cfg_N] + list(g_out.size())[1:]
        dataset = TensorDataset(zs, fakes.view(sz))
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            pin_memory=torch.cuda.is_available())
        self._approximate(loader, g_loss)
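
to_pm1 is not shown either; going by the "# -> [-1,1]" comment above, here is a minimal in-place sketch assuming simple min-max rescaling (an assumption, not the repository's helper):

import torch

def to_pm1(x):
    # Hypothetical: rescale a tensor in place to [-1, 1], matching the
    # value range produced by a tanh generator.
    mn, mx = x.min(), x.max()
    x.sub_(mn).div_((mx - mn).clamp(min=1e-8)).mul_(2.0).sub_(1.0)
    return x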
Example #3: fcn_G
def fcn_G(input_dim, nn, imgsz, channels, requires_grad, depth=2):
   def gen_block_params(ni, no):
      return {'fc': utils.linear_params(ni, no),}

   def gen_group_params(ni, no, count):
      return {'block%d' % i: gen_block_params(ni if i == 0 else no, no) for i in range(count)}

   flat_params = utils.cast(utils.flatten({
        'group0': gen_group_params(input_dim, nn, depth),
        'last_proj': utils.linear_params(nn, imgsz*imgsz*channels),
   }))

   if requires_grad:
      utils.set_requires_grad_except_bn_(flat_params)

   def block(x, params, base, mode):
      return F.relu(F.linear(x, params[base+'.fc.weight'], params[base+'.fc.bias']), inplace=True)

   def group(o, params, base, mode):
      for i in range(depth):
         o = block(o, params, '%s.block%d' % (base,i), mode)
      return o

   def f(input, params, mode):
      o = group(input, params, 'group0', mode)
      o = F.linear(o, params['last_proj.weight'], params['last_proj.bias'])
      o = torch.tanh(o)
      o = o.reshape(o.size(0), channels, imgsz, imgsz)
      return o

   return f, flat_params
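
A quick usage sketch (the hyperparameter values are illustrative assumptions; utils.cast is the casting helper used throughout these examples):

import torch

# Build the fully connected generator for 32x32 color images.
f, flat_params = fcn_G(input_dim=100, nn=512, imgsz=32, channels=3,
                       requires_grad=True)
z = utils.cast(torch.randn(16, 100))   # a batch of latent vectors
imgs = f(z, flat_params, True)         # mode flag (no batch norm here)
print(imgs.shape)                      # torch.Size([16, 3, 32, 32])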
Example #4: dcganx_D
def dcganx_D(nn0, imgsz,
             channels,    # 1: gray-scale, 3: color
             norm_type,   # 'bn', 'none'
             requires_grad, 
             depth=3, leaky_slope=0.2, nodemul=2, do_bias=True):
              
   ker=5; padding=2
   
   def gen_block_params(ni, no, k):
      return {
         'conv0': conv2d_params(ni, no, k, do_bias), 
         'conv1': conv2d_params(no, no, 1, do_bias), 
         'bn0': utils.bnparams(no) if norm_type == 'bn' else None, 
         'bn1': utils.bnparams(no) if norm_type == 'bn' else None
      }

   def gen_group_params(ni, no, count):
       return {'block%d' % i: gen_block_params(ni if i == 0 else no, no, ker) for i in range(count)}

   count = 1
   sz = imgsz // (2**depth)
   nn = nn0
   p = { 'conv0': conv2d_params(channels, nn0, ker, do_bias) }
   for d in range(depth-1):
      p['group%d'%d] = gen_group_params(nn, nn*nodemul, count)
      nn = nn*nodemul
   p['fc'] = utils.linear_params(sz*sz*nn, 1)
   flat_params = utils.cast(utils.flatten(p))

   if requires_grad:
      utils.set_requires_grad_except_bn_(flat_params)

   def block(x, params, base, mode, stride):
      o = F.conv2d(x, params[base+'.conv0.w'], params.get(base+'.conv0.b'), stride=stride, padding=padding)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn0', mode)
      o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
      o = F.conv2d(o, params[base+'.conv1.w'], params.get(base+'.conv1.b'), stride=1, padding=0)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn1', mode)
      o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
      return o

   def group(o, params, base, mode, stride=2):
      for i in range(count):
         o = block(o, params, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
      return o

   def f(input, params, mode):
      o = F.conv2d(input, params['conv0.w'], params.get('conv0.b'), stride=2, padding=padding)
      o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
      for d in range(depth-1):
         o = group(o, params, 'group%d'%d, mode)
      o = o.view(o.size(0), -1)
      o = F.linear(o, params['fc.weight'], params['fc.bias'])
      return o

   return f, flat_params
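
And a matching usage sketch (sizes are illustrative assumptions):

import torch

# Score a batch of 32x32 color images with the depth-3 discriminator.
d, d_params = dcganx_D(nn0=64, imgsz=32, channels=3, norm_type='bn',
                       requires_grad=True)
x = utils.cast(torch.randn(8, 3, 32, 32))
scores = d(x, d_params, True)   # one unbounded logit per image
print(scores.shape)             # torch.Size([8, 1])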
Example #5: DDG._approximate
    def _approximate(self, loader, g_loss):
        if self.verbose:
            timeLog('DDG::_approximate using %d data points ...' %
                    len(loader.dataset))
        self.check_trainability()
        with torch.no_grad():
            g_params = clone_params(self.g_params, do_copy_requires_grad=True)

        optimizer = self.optim_config.create_optimizer(g_params)
        mtr_loss = tnt.meter.AverageValueMeter()
        last_loss_mean = float('inf')
        is_train = True
        for epoch in range(self.optim_config.cfg_x_epo):
            for sample in loader:
                z = cast(sample[0])
                target_fake = cast(sample[1])

                fake = self.g_net(z, g_params, is_train)

                loss = g_loss(fake, target_fake)
                mtr_loss.add(float(loss))
                loss.backward()

                optimizer.step()
                optimizer.zero_grad()

            loss_mean = mtr_loss.value()[0]
            if self.verbose:
                logging('%d ... %s ... ' % (epoch, str(loss_mean)))
            # Reduce the learning rate whenever the epoch loss worsens.
            if loss_mean > last_loss_mean:
                self.optim_config.reduce_lr_(optimizer)
            raise_if_nan(loss_mean)

            last_loss_mean = loss_mean
            mtr_loss.reset()

        copy_params(src=g_params, dst=self.g_params)
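
clone_params and copy_params are not shown in these examples; below are minimal sketches of the semantics the method relies on (assumptions, not the repository's helpers):

import torch

def clone_params(params, do_copy_requires_grad=False):
    # Deep-copy each tensor; optionally preserve requires_grad so the
    # clones can be trained independently of the originals.
    return {k: v.clone().detach().requires_grad_(
                v.requires_grad and do_copy_requires_grad)
            for k, v in params.items()}

def copy_params(src, dst):
    # Write the trained values back into the live parameters.
    with torch.no_grad():
        for k in dst:
            dst[k].copy_(src[k])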
Example #6: DDG.generate
    def generate(self, num_gen, t=-1, do_return_z=False, batch_size=-1):
        assert num_gen > 0
        if t < 0:
            t = self.num_D()
        if batch_size <= 0:
            batch_size = num_gen

        num_gened = 0
        fakes = None
        zs = None
        is_train = False
        while num_gened < num_gen:
            num = min(batch_size, num_gen - num_gened)
            with torch.no_grad():
                z = self.z_gen(num)
                fake = self.g_net(cast(z), self.g_params, is_train)
            # Refine the fakes: one small gradient-ascent step on each
            # stored discriminator's output (step size cfg_eta).
            for t0 in range(t):
                fake = fake.detach()
                if fake.grad is not None:
                    fake.grad.zero_()
                fake.requires_grad = True

                d_out = self.d_net(fake, self.get_d_params(t0), True)
                d_out.backward(torch.ones_like(d_out))
                fake.data += self.cfg_eta * fake.grad.data

            if fakes is None:
                sz = [num_gen] + list(fake.size())[1:]
                fakes = torch.empty(sz, device=torch.device('cpu'))

            fakes[num_gened:num_gened + num] = fake.to(torch.device('cpu'))

            if do_return_z:
                if zs is None:
                    sz = [num_gen] + list(z.size())[1:]
                    zs = torch.empty(sz, device=z.device)
                zs[num_gened:num_gened + num] = z

            num_gened += num

        fakes.detach_()
        if do_return_z:
            return fakes, zs
        else:
            return fakes
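
Note the division of labor here: the base generator g_net runs under no_grad, and each of the t stored discriminator snapshots then moves the samples one step of size cfg_eta along the gradient of its own output, so a generated batch is effectively "base generator plus t functional-gradient refinements".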
Example #7: dcganx_G
def dcganx_G(input_dim, n0g, imgsz, channels,
             norm_type,  # 'bn', 'none'
             requires_grad, depth=3, 
             nodemul=2, do_bias=True):
              
   ker=5; padding=2; output_padding=1

   def gen_block_T_params(ni, no, k):
      return {
         'convT0': conv2dT_params(ni, no, k, do_bias), 
         'conv1': conv2d_params(no, no, 1, do_bias), 
         'bn0': utils.bnparams(no) if norm_type == 'bn' else None, 
         'bn1': utils.bnparams(no) if norm_type == 'bn' else None
      }

   def gen_group_T_params(ni, no, count):
       return {'block%d' % i: gen_block_T_params(ni if i == 0 else no, no, ker) for i in range(count)}

   count = 1
   nn0 = n0g * (nodemul**(depth-1))
   sz = imgsz // (2**depth)  
   p = { 'proj': utils.linear_params(input_dim, nn0*sz*sz) }
   nn = nn0
   for d in range(depth-1):
      p['group%d'%d] = gen_group_T_params(nn, nn//nodemul, count)
      nn = nn//nodemul
   p['last_convT'] = conv2dT_params(nn, channels, ker, do_bias)
   flat_params = utils.cast(utils.flatten(p))

   if requires_grad:
      utils.set_requires_grad_except_bn_(flat_params)

   def block(x, params, base, mode, stride):
      o = F.relu(x, inplace=True)
      o = F.conv_transpose2d(o, params[base+'.convT0.w'], params.get(base+'.convT0.b'),
                             stride=stride, padding=padding, output_padding=output_padding)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn0', mode)

      o = F.relu(o, inplace=True)
      o = F.conv2d(o, params[base+'.conv1.w'], params.get(base+'.conv1.b'),
                   stride=1, padding=0)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn1', mode)
      return o

   def group(o, params, base, mode, stride=2):
      for i in range(count):
         o = block(o, params, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
      return o

   def f(input, params, mode):
      o = F.linear(input, params['proj.weight'], params['proj.bias'])
      o = o.view(input.size(0), nn0, sz, sz)
      for d in range(depth-1):
        o = group(o, params, 'group%d'%d, mode)
      o = F.relu(o, inplace=True)
      o = F.conv_transpose2d(o, params['last_convT.w'], params.get('last_convT.b'), stride=2,
                             padding=padding, output_padding=output_padding)
      o = torch.tanh(o)
      return o

   return f, flat_params
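
Usage mirrors the fully connected generator (sizes are illustrative assumptions). With ker=5, stride=2, padding=2, and output_padding=1, every strided transposed convolution exactly doubles the spatial size, so the depth-1 groups plus last_convT grow the initial imgsz/2^depth map back to imgsz:

import torch

g, g_params = dcganx_G(input_dim=100, n0g=64, imgsz=32, channels=3,
                       norm_type='bn', requires_grad=True)
z = utils.cast(torch.randn(8, 100))
imgs = g(z, g_params, True)
print(imgs.shape)   # torch.Size([8, 3, 32, 32]), values in [-1,1] from tanh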
Example #8: resnet4_G
def resnet4_G(input_dim, n0g, imgsz, channels,
             norm_type,  # 'bn', 'none'
             requires_grad,
             do_bias=True):         
   depth = 4
   ker = 3
   padding = (ker-1)//2
   count = 1

   def gen_resnet_G_block_params(ni, no, k, norm_type, do_bias):
      return {
         'conv0': conv2d_params(ni, no, k, do_bias), 
         'conv1': conv2d_params(no, no, k, do_bias), 
         'convdim': utils.conv_params(ni, no, 1), 
         'bn': utils.bnparams(no) if norm_type == 'bn' else None
      }

   def gen_group_params(ni, no):
       return {'block%d' % i: gen_resnet_G_block_params(ni if i == 0 else no, no, ker, norm_type, do_bias) for i in range(count)}

   nn = n0g * (2**(depth-1))
   sz = imgsz // (2**depth)
   flat_params = utils.cast(utils.flatten({
        'proj': utils.linear_params(input_dim, nn*sz*sz),
        'group0': gen_group_params(nn,    nn//2),
        'group1': gen_group_params(nn//2, nn//4),
        'group2': gen_group_params(nn//4, nn//8),
        'group3': gen_group_params(nn//8, nn//8),        
        'last_conv': conv2d_params(nn//8, channels, ker, do_bias),
   }))

   if requires_grad:
      utils.set_requires_grad_except_bn_(flat_params)

   def block(x, params, base, mode, do_upsample):
      o = F.relu(x, inplace=True)
      if do_upsample:
         o = F.interpolate(o, scale_factor=2, mode='nearest')

      o = F.conv2d(o, params[base+'.conv0.w'], params.get(base+'.conv0.b'), padding=padding)
      o = F.relu(o, inplace=True)
      o = F.conv2d(o, params[base+'.conv1.w'], params.get(base+'.conv1.b'), padding=padding)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn', mode)
         
      xo = F.conv2d(x, params[base + '.convdim']) 
      if do_upsample:
         return o + F.interpolate(xo, scale_factor=2, mode='nearest')
      else:
         return o + xo
 
   def group(o, params, base, mode, do_upsample):
      for i in range(count):
         o = block(o, params, '%s.block%d' % (base,i), mode, do_upsample if i == 0 else False)
      return o

   def show_shape(o, msg=''):  # debugging aid (unused below)
      print(o.size(), msg)

   def f(input, params, mode):
      o = F.linear(input, params['proj.weight'], params['proj.bias'])
      o = o.view(input.size(0), nn, sz, sz)
      o = group(o, params, 'group0', mode, do_upsample=True)
      o = group(o, params, 'group1', mode, do_upsample=True)
      o = group(o, params, 'group2', mode, do_upsample=True)
      o = group(o, params, 'group3', mode, do_upsample=True)
      o = F.relu(o, inplace=True)
      o = F.conv2d(o, params['last_conv.w'], params.get('last_conv.b'), padding=padding)
      o = torch.tanh(o)
      return o

   return f, flat_params   
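
A worked size check: with depth=4 the projection produces an sz = imgsz//16 map with n0g*8 channels; each of the four groups doubles the spatial size while halving the channels (group3 keeps nn//8), and last_conv maps down to the image channels. For imgsz=64, generation starts from a 4x4 map and ends at 64x64.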
   
Example #9: resnet4_D
def resnet4_D(nn, imgsz,
              channels,    # 1: gray-scale, 3: color
              norm_type,  # 'bn', 'none'
              requires_grad,
              do_bias=True):             
   depth = 4
   ker = 3
   padding = (ker-1)//2
   count = 1

   def gen_group0_params(no):
      ni = channels
      return { 'block0' : {
         'conv0': conv2d_params(ni, no, ker, do_bias), 
         'conv1': conv2d_params(no, no, ker, do_bias), 
         'convdim': utils.conv_params(ni, no, 1), 
         'bn': utils.bnparams(no) if norm_type == 'bn' else None
      }}

   def gen_resnet_D_block_params(ni, no, k, norm_type, do_bias):
      return {
         'conv0': conv2d_params(ni, ni, k, do_bias), 
         'conv1': conv2d_params(ni, no, k, do_bias), 
         'convdim': utils.conv_params(ni, no, 1), 
         'bn': utils.bnparams(no) if norm_type == 'bn' else None
      }

   def gen_group_params(ni, no):
       return {'block%d' % i: gen_resnet_D_block_params(ni if i == 0 else no, no, ker, norm_type, do_bias) for i in range(count)}

   sz = imgsz // (2**depth)
   flat_params = utils.cast(utils.flatten({
        'group0': gen_group0_params(nn),
        'group1': gen_group_params(nn,   nn*2),
        'group2': gen_group_params(nn*2, nn*4),
        'group3': gen_group_params(nn*4, nn*8),        
        'fc': utils.linear_params(sz*sz*nn*8, 1),
   }))

   if requires_grad:
      utils.set_requires_grad_except_bn_(flat_params)

   def block(x, params, base, mode, do_downsample, is_first):
      o = x
      if not is_first:
         o = F.relu(o, inplace=True)   
      o = F.conv2d(o, params[base+'.conv0.w'], params.get(base+'.conv0.b'), padding=padding)
      o = F.relu(o, inplace=True)
      o = F.conv2d(o, params[base+'.conv1.w'], params.get(base+'.conv1.b'), padding=padding)
      if norm_type == 'bn':
         o = utils.batch_norm(o, params, base + '.bn', mode)
 
      if do_downsample:
         o = F.avg_pool2d(o,2)
         x = F.avg_pool2d(x,2)
      
      if base + '.convdim' in params:
         return o + F.conv2d(x, params[base + '.convdim'])
      else:
         return o + x


   def group(o, params, base, mode, do_downsample, is_first=False):
      for i in range(count):
         o = block(o, params, '%s.block%d' % (base,i), mode, 
                   do_downsample=(do_downsample and i == count-1), 
                   is_first=(is_first and i == 0))                   
      return o

   def f(input, params, mode):
      o = group(input, params, 'group0', mode, do_downsample=True, is_first=True)
      o = group(o, params, 'group1', mode, do_downsample=True)
      o = group(o, params, 'group2', mode, do_downsample=True)
      o = group(o, params, 'group3', mode, do_downsample=True)      
      o = F.relu(o, inplace=True)
      o = o.view(o.size(0), -1)
      o = F.linear(o, params['fc.weight'], params['fc.bias'])
      return o

   return f, flat_params
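
A quick end-to-end sketch pairing the two ResNet builders (sizes are illustrative assumptions):

import torch

g, gp = resnet4_G(input_dim=128, n0g=64, imgsz=64, channels=3,
                  norm_type='bn', requires_grad=True)
d, dp = resnet4_D(nn=64, imgsz=64, channels=3, norm_type='bn',
                  requires_grad=True)
z = utils.cast(torch.randn(4, 128))
scores = d(g(z, gp, True), dp, True)
print(scores.shape)   # torch.Size([4, 1]); each D group halves 64 -> 4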