import torch
import torch.nn.functional as F
import torchnet as tnt
from torch.nn.init import normal_
from torch.utils.data import DataLoader, TensorDataset

# Repo-local helpers assumed to be in scope: cast, timeLog, logging, get_next,
# raise_if_nan, clone_params, copy_params, to_pm1, conv2d_params,
# conv2dT_params, and the `utils` module (linear_params, conv_params, bnparams,
# batch_norm, flatten, cast, set_requires_grad_except_bn_).

def icfg(self, loader, iter, d_loss, cfg_U):
    timeLog('DDG::icfg ... ICFG with cfg_U=%d' % cfg_U)
    self.check_trainability()
    t_inc = 1 if self.verbose else 5
    is_train = True
    for t in range(self.num_D()):
        sum_real = sum_fake = count = 0
        for upd in range(cfg_U):  # cfg_U discriminator updates per stage
            sample, iter = get_next(loader, iter)
            num = sample[0].size(0)
            fake = self.generate(num, t=t)  # fakes refined by the first t discriminators
            d_out_real = self.d_net(cast(sample[0]), self.d_params, is_train)
            d_out_fake = self.d_net(cast(fake), self.d_params, is_train)
            loss = d_loss(d_out_real, d_out_fake)
            loss.backward()
            self.d_optimizer.step()
            self.d_optimizer.zero_grad()
            with torch.no_grad():
                sum_real += float(d_out_real.sum())
                sum_fake += float(d_out_fake.sum())
                count += num
        self.store_d_params(t)  # keep a snapshot of D for stage t
        if t_inc > 0 and ((t + 1) % t_inc == 0 or t == self.num_D() - 1):
            logging('  t=%d: real,%s, fake,%s' % (t + 1, sum_real / count, sum_fake / count))
        raise_if_nan(sum_real)
        raise_if_nan(sum_fake)
    return iter, (sum_real - sum_fake) / count
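# icfg() accepts any `d_loss` callable taking (d_out_real, d_out_fake) and
# returning a scalar to minimize. The minimal logistic-loss example below
# illustrates the expected interface; it is an assumption, not necessarily
# the loss this repo itself uses.
def logistic_d_loss(d_out_real, d_out_fake):
    # -log sigmoid(D(real)) - log sigmoid(-D(fake)), averaged over the batch
    return F.softplus(-d_out_real).mean() + F.softplus(d_out_fake).mean()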
def initialize_G(self, g_loss, cfg_N):
    timeLog('DDG::initialize_G ... Initializing tilde(G) ...')
    z = self.z_gen(1)
    g_out = self.g_net(cast(z), self.g_params, False)
    img_dim = g_out.view(g_out.size(0), -1).size(1)
    batch_size = self.optim_config.x_batch_size
    z_dim = self.z_gen(1).size(1)

    # A random linear projection of z serves as the initial target for G.
    params = {'proj.w': normal_(torch.Tensor(z_dim, img_dim), std=0.01)}
    params['proj.w'].requires_grad = True

    num_gened = 0
    fakes = torch.Tensor(cfg_N, img_dim)
    zs = torch.Tensor(cfg_N, z_dim)
    with torch.no_grad():
        while num_gened < cfg_N:
            num = min(batch_size, cfg_N - num_gened)
            z = self.z_gen(num)
            fake = torch.mm(z, params['proj.w'])
            fakes[num_gened:num_gened + num] = fake
            zs[num_gened:num_gened + num] = z
            num_gened += num
    to_pm1(fakes)  # rescale to [-1,1], the range of the tanh generators below
    sz = [cfg_N] + list(g_out.size())[1:]
    dataset = TensorDataset(zs, fakes.view(sz))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        pin_memory=torch.cuda.is_available())
    self._approximate(loader, g_loss)
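# `to_pm1` above is a repo-local helper; the following sketch of a min-max
# rescale to [-1,1] is an assumption shown only to document the in-place
# contract that initialize_G relies on.
def to_pm1_sketch(x):
    lo, hi = float(x.min()), float(x.max())
    if hi > lo:
        x.sub_(lo).div_((hi - lo) / 2).sub_(1)  # maps [lo,hi] -> [-1,1]
    return x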
def fcn_G(input_dim, nn, imgsz, channels, requires_grad, depth=2):
    def gen_block_params(ni, no):
        return {'fc': utils.linear_params(ni, no)}

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    flat_params = utils.cast(utils.flatten({
        'group0': gen_group_params(input_dim, nn, depth),
        'last_proj': utils.linear_params(nn, imgsz * imgsz * channels),
    }))
    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode):
        return F.relu(F.linear(x, params[base + '.fc.weight'], params[base + '.fc.bias']),
                      inplace=True)

    def group(o, params, base, mode):
        for i in range(depth):
            o = block(o, params, '%s.block%d' % (base, i), mode)
        return o

    def f(input, params, mode):
        o = group(input, params, 'group0', mode)
        o = F.linear(o, params['last_proj.weight'], params['last_proj.bias'])
        o = torch.tanh(o)
        o = o.reshape(o.size(0), channels, imgsz, imgsz)  # reshape handles non-contiguous output
        return o

    return f, flat_params
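# Hypothetical smoke test for fcn_G (assumes CUDA is unavailable so that
# utils.cast keeps the parameters on the CPU):
def _fcn_G_smoke_test():
    g, gp = fcn_G(input_dim=100, nn=512, imgsz=28, channels=1, requires_grad=True)
    out = g(torch.randn(4, 100), gp, mode=False)
    assert out.shape == (4, 1, 28, 28)  # (batch, channels, imgsz, imgsz)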
def dcganx_D(nn0, imgsz,
             channels,    # 1: gray-scale, 3: color
             norm_type,   # 'bn', 'none'
             requires_grad, depth=3, leaky_slope=0.2, nodemul=2, do_bias=True):
    ker = 5; padding = 2

    def gen_block_params(ni, no, k):
        return {
            'conv0': conv2d_params(ni, no, k, do_bias),
            'conv1': conv2d_params(no, no, 1, do_bias),
            'bn0': utils.bnparams(no) if norm_type == 'bn' else None,
            'bn1': utils.bnparams(no) if norm_type == 'bn' else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no, ker)
                for i in range(count)}

    count = 1
    sz = imgsz // (2**depth)
    nn = nn0
    p = {'conv0': conv2d_params(channels, nn0, ker, do_bias)}
    for d in range(depth - 1):
        p['group%d' % d] = gen_group_params(nn, nn * nodemul, count)
        nn = nn * nodemul
    p['fc'] = utils.linear_params(sz * sz * nn, 1)

    flat_params = utils.cast(utils.flatten(p))
    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride):
        o = F.conv2d(x, params[base + '.conv0.w'], params.get(base + '.conv0.b'),
                     stride=stride, padding=padding)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn0', mode)
        o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
        o = F.conv2d(o, params[base + '.conv1.w'], params.get(base + '.conv1.b'),
                     stride=1, padding=0)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn1', mode)
        o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
        return o

    def group(o, params, base, mode, stride=2):
        for i in range(count):
            o = block(o, params, '%s.block%d' % (base, i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        o = F.conv2d(input, params['conv0.w'], params.get('conv0.b'),
                     stride=2, padding=padding)
        o = F.leaky_relu(o, negative_slope=leaky_slope, inplace=True)
        for d in range(depth - 1):
            o = group(o, params, 'group%d' % d, mode)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
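# Shape arithmetic for dcganx_D at imgsz=32, depth=3: conv0 halves 32->16, the
# depth-1=2 strided groups give 16->8->4, so fc sees 4*4*(nn0*nodemul**2)
# features. A hypothetical smoke test (again assuming CPU parameters):
def _dcganx_D_smoke_test():
    d, dp = dcganx_D(nn0=64, imgsz=32, channels=3, norm_type='none',
                     requires_grad=True)
    out = d(torch.randn(4, 3, 32, 32), dp, mode=False)
    assert out.shape == (4, 1)  # one score per image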
def _approximate(self, loader, g_loss):
    if self.verbose:
        timeLog('DDG::_approximate using %d data points ...' % len(loader.dataset))
    self.check_trainability()
    with torch.no_grad():
        g_params = clone_params(self.g_params, do_copy_requires_grad=True)
    optimizer = self.optim_config.create_optimizer(g_params)
    mtr_loss = tnt.meter.AverageValueMeter()
    last_loss_mean = 99999999
    is_train = True
    for epoch in range(self.optim_config.cfg_x_epo):
        for sample in loader:
            z = cast(sample[0])
            target_fake = cast(sample[1])
            fake = self.g_net(z, g_params, is_train)
            loss = g_loss(fake, target_fake)
            mtr_loss.add(float(loss))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        loss_mean = mtr_loss.value()[0]
        if self.verbose:
            logging('%d ... %s ...' % (epoch, str(loss_mean)))
        if loss_mean > last_loss_mean:
            self.optim_config.reduce_lr_(optimizer)  # decay lr when the epoch loss stops improving
        raise_if_nan(loss_mean)
        last_loss_mean = loss_mean
        mtr_loss.reset()
    copy_params(src=g_params, dst=self.g_params)
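# `clone_params` / `copy_params` above are repo-local; this sketch of the
# contract _approximate relies on is an assumption, for illustration only.
def clone_params_sketch(params, do_copy_requires_grad=False):
    out = {k: (v.detach().clone() if v is not None else None)
           for k, v in params.items()}
    if do_copy_requires_grad:
        for k, v in params.items():
            if v is not None:
                out[k].requires_grad_(v.requires_grad)
    return out

def copy_params_sketch(src, dst):
    with torch.no_grad():
        for k, v in dst.items():
            if v is not None:
                v.copy_(src[k])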
def generate(self, num_gen, t=-1, do_return_z=False, batch_size=-1):
    assert num_gen > 0
    if t < 0:
        t = self.num_D()
    if batch_size <= 0:
        batch_size = num_gen
    num_gened = 0
    fakes = None
    zs = None
    is_train = False
    while num_gened < num_gen:
        num = min(batch_size, num_gen - num_gened)
        with torch.no_grad():
            z = self.z_gen(num)
            fake = self.g_net(cast(z), self.g_params, is_train)
        # Functional-gradient refinement: one gradient-ascent step on each of
        # the first t stored discriminators, fake <- fake + eta * grad D(fake).
        for t0 in range(t):
            fake = fake.detach()
            if fake.grad is not None:
                fake.grad.zero_()
            fake.requires_grad = True
            d_out = self.d_net(fake, self.get_d_params(t0), True)
            d_out.backward(torch.ones_like(d_out))  # per-sample gradients of D
            fake.data += self.cfg_eta * fake.grad.data
        if fakes is None:
            sz = [num_gen] + list(fake.size())[1:]
            fakes = torch.empty(torch.Size(sz), device=torch.device('cpu'))
        fakes[num_gened:num_gened + num] = fake.to(torch.device('cpu'))
        if do_return_z:
            if zs is None:
                sz = [num_gen] + list(z.size())[1:]
                zs = torch.empty(torch.Size(sz), device=z.device)
            zs[num_gened:num_gened + num] = z
        num_gened += num
    fakes.detach_()
    if do_return_z:
        return fakes, zs
    return fakes
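# A standalone illustration of one refinement step from generate() above:
# gradient ascent on the discriminator output, x <- x + eta * grad_x D(x).
# (`d_net`, `d_params`, and `eta` are stand-ins for the members used above.)
def fg_step_sketch(x, d_net, d_params, eta):
    x = x.detach().requires_grad_(True)
    d_out = d_net(x, d_params, True)
    d_out.backward(torch.ones_like(d_out))  # grads of sum(D(x)) w.r.t. x
    with torch.no_grad():
        return x + eta * x.grad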
def dcganx_G(input_dim, n0g, imgsz, channels,
             norm_type,  # 'bn', 'none'
             requires_grad, depth=3, nodemul=2, do_bias=True):
    ker = 5; padding = 2; output_padding = 1

    def gen_block_T_params(ni, no, k):
        return {
            'convT0': conv2dT_params(ni, no, k, do_bias),
            'conv1': conv2d_params(no, no, 1, do_bias),
            'bn0': utils.bnparams(no) if norm_type == 'bn' else None,
            'bn1': utils.bnparams(no) if norm_type == 'bn' else None,
        }

    def gen_group_T_params(ni, no, count):
        return {'block%d' % i: gen_block_T_params(ni if i == 0 else no, no, ker)
                for i in range(count)}

    count = 1
    nn0 = n0g * (nodemul**(depth - 1))
    sz = imgsz // (2**depth)
    p = {'proj': utils.linear_params(input_dim, nn0 * sz * sz)}
    nn = nn0
    for d in range(depth - 1):
        p['group%d' % d] = gen_group_T_params(nn, nn // nodemul, count)
        nn = nn // nodemul
    p['last_convT'] = conv2dT_params(nn, channels, ker, do_bias)

    flat_params = utils.cast(utils.flatten(p))
    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride):
        o = F.relu(x, inplace=True)
        o = F.conv_transpose2d(o, params[base + '.convT0.w'], params.get(base + '.convT0.b'),
                               stride=stride, padding=padding, output_padding=output_padding)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn0', mode)
        o = F.relu(o, inplace=True)
        o = F.conv2d(o, params[base + '.conv1.w'], params.get(base + '.conv1.b'),
                     stride=1, padding=0)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn1', mode)
        return o

    def group(o, params, base, mode, stride=2):
        for i in range(count):
            o = block(o, params, '%s.block%d' % (base, i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        o = F.linear(input, params['proj.weight'], params['proj.bias'])
        o = o.view(input.size(0), nn0, sz, sz)
        for d in range(depth - 1):
            o = group(o, params, 'group%d' % d, mode)
        o = F.relu(o, inplace=True)
        o = F.conv_transpose2d(o, params['last_convT.w'], params.get('last_convT.b'),
                               stride=2, padding=padding, output_padding=output_padding)
        o = torch.tanh(o)
        return o

    return f, flat_params
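# `conv2d_params` / `conv2dT_params` above are repo-local; a plausible sketch
# (an assumption) consistent with the '.w'/'.b' parameter keys used here:
from torch.nn.init import kaiming_normal_

def conv2d_params_sketch(ni, no, k, do_bias):
    p = {'w': kaiming_normal_(torch.empty(no, ni, k, k))}
    if do_bias:
        p['b'] = torch.zeros(no)
    return p

def conv2dT_params_sketch(ni, no, k, do_bias):
    # conv_transpose2d expects weights shaped (in_channels, out_channels, k, k)
    p = {'w': kaiming_normal_(torch.empty(ni, no, k, k))}
    if do_bias:
        p['b'] = torch.zeros(no)
    return p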
def resnet4_G(input_dim, n0g, imgsz, channels,
              norm_type,  # 'bn', 'none'
              requires_grad, do_bias=True):
    depth = 4
    ker = 3
    padding = (ker - 1) // 2
    count = 1

    def gen_resnet_G_block_params(ni, no, k, norm_type, do_bias):
        return {
            'conv0': conv2d_params(ni, no, k, do_bias),
            'conv1': conv2d_params(no, no, k, do_bias),
            'convdim': utils.conv_params(ni, no, 1),
            'bn': utils.bnparams(no) if norm_type == 'bn' else None,
        }

    def gen_group_params(ni, no):
        return {'block%d' % i: gen_resnet_G_block_params(ni if i == 0 else no, no,
                                                         ker, norm_type, do_bias)
                for i in range(count)}

    nn = n0g * (2**(depth - 1))
    sz = imgsz // (2**depth)
    flat_params = utils.cast(utils.flatten({
        'proj': utils.linear_params(input_dim, nn * sz * sz),
        'group0': gen_group_params(nn, nn // 2),
        'group1': gen_group_params(nn // 2, nn // 4),
        'group2': gen_group_params(nn // 4, nn // 8),
        'group3': gen_group_params(nn // 8, nn // 8),
        'last_conv': conv2d_params(nn // 8, channels, ker, do_bias),
    }))
    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, do_upsample):
        o = F.relu(x, inplace=True)
        if do_upsample:
            o = F.interpolate(o, scale_factor=2, mode='nearest')
        o = F.conv2d(o, params[base + '.conv0.w'], params.get(base + '.conv0.b'),
                     padding=padding)
        o = F.relu(o, inplace=True)
        o = F.conv2d(o, params[base + '.conv1.w'], params.get(base + '.conv1.b'),
                     padding=padding)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn', mode)
        xo = F.conv2d(x, params[base + '.convdim'])  # 1x1 conv on the shortcut
        if do_upsample:
            return o + F.interpolate(xo, scale_factor=2, mode='nearest')
        return o + xo

    def group(o, params, base, mode, do_upsample):
        for i in range(count):
            o = block(o, params, '%s.block%d' % (base, i), mode,
                      do_upsample if i == 0 else False)
        return o

    def show_shape(o, msg=''):  # debug helper
        print(o.size(), msg)

    def f(input, params, mode):
        o = F.linear(input, params['proj.weight'], params['proj.bias'])
        o = o.view(input.size(0), nn, sz, sz)
        o = group(o, params, 'group0', mode, do_upsample=True)
        o = group(o, params, 'group1', mode, do_upsample=True)
        o = group(o, params, 'group2', mode, do_upsample=True)
        o = group(o, params, 'group3', mode, do_upsample=True)
        o = F.relu(o, inplace=True)
        o = F.conv2d(o, params['last_conv.w'], params.get('last_conv.b'),
                     padding=padding)
        o = torch.tanh(o)
        return o

    return f, flat_params
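# Shape arithmetic for resnet4_G at imgsz=64: proj yields (n0g*8, 4, 4) maps
# and the four upsampling groups give 4->8->16->32->64. A hypothetical smoke
# test (assuming CPU parameters):
def _resnet4_G_smoke_test():
    g, gp = resnet4_G(input_dim=100, n0g=64, imgsz=64, channels=3,
                      norm_type='none', requires_grad=True)
    out = g(torch.randn(2, 100), gp, mode=False)
    assert out.shape == (2, 3, 64, 64)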
def resnet4_D(nn, imgsz,
              channels,   # 1: gray-scale, 3: color
              norm_type,  # 'bn', 'none'
              requires_grad, do_bias=True):
    depth = 4
    ker = 3
    padding = (ker - 1) // 2
    count = 1

    def gen_group0_params(no):
        ni = channels
        return {'block0': {
            'conv0': conv2d_params(ni, no, ker, do_bias),
            'conv1': conv2d_params(no, no, ker, do_bias),
            'convdim': utils.conv_params(ni, no, 1),
            'bn': utils.bnparams(no) if norm_type == 'bn' else None,
        }}

    def gen_resnet_D_block_params(ni, no, k, norm_type, do_bias):
        return {
            'conv0': conv2d_params(ni, ni, k, do_bias),
            'conv1': conv2d_params(ni, no, k, do_bias),
            'convdim': utils.conv_params(ni, no, 1),
            'bn': utils.bnparams(no) if norm_type == 'bn' else None,
        }

    def gen_group_params(ni, no):
        return {'block%d' % i: gen_resnet_D_block_params(ni if i == 0 else no, no,
                                                         ker, norm_type, do_bias)
                for i in range(count)}

    sz = imgsz // (2**depth)
    flat_params = utils.cast(utils.flatten({
        'group0': gen_group0_params(nn),
        'group1': gen_group_params(nn, nn * 2),
        'group2': gen_group_params(nn * 2, nn * 4),
        'group3': gen_group_params(nn * 4, nn * 8),
        'fc': utils.linear_params(sz * sz * nn * 8, 1),
    }))
    if requires_grad:
        utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, do_downsample, is_first):
        o = x
        if not is_first:
            o = F.relu(o, inplace=True)
        o = F.conv2d(o, params[base + '.conv0.w'], params.get(base + '.conv0.b'),
                     padding=padding)
        o = F.relu(o, inplace=True)
        o = F.conv2d(o, params[base + '.conv1.w'], params.get(base + '.conv1.b'),
                     padding=padding)
        if norm_type == 'bn':
            o = utils.batch_norm(o, params, base + '.bn', mode)
        if do_downsample:
            o = F.avg_pool2d(o, 2)
            x = F.avg_pool2d(x, 2)
        if base + '.convdim' in params:
            return o + F.conv2d(x, params[base + '.convdim'])  # 1x1 conv shortcut
        return o + x

    def group(o, params, base, mode, do_downsample, is_first=False):
        for i in range(count):
            o = block(o, params, '%s.block%d' % (base, i), mode,
                      do_downsample=(do_downsample and i == count - 1),
                      is_first=(is_first and i == 0))
        return o

    def f(input, params, mode):
        o = group(input, params, 'group0', mode, do_downsample=True, is_first=True)
        o = group(o, params, 'group1', mode, do_downsample=True)
        o = group(o, params, 'group2', mode, do_downsample=True)
        o = group(o, params, 'group3', mode, do_downsample=True)
        o = F.relu(o, inplace=True)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
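# End-to-end shape check pairing resnet4_G with resnet4_D (hypothetical,
# assuming CPU parameters): a 64x64 generator output feeds a discriminator
# that downsamples 64->32->16->8->4 and emits one score per image.
def _resnet4_pair_smoke_test():
    g, gp = resnet4_G(100, 64, 64, 3, 'none', requires_grad=False)
    d, dp = resnet4_D(64, 64, 3, 'none', requires_grad=False)
    fake = g(torch.randn(2, 100), gp, mode=False)
    score = d(fake, dp, mode=False)
    assert score.shape == (2, 1)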