def __init__(self, opt, d_config, g_config, z_gen, optim_config, from_file=None):
    assert opt.cfg_T > 0
    self.verbose = opt.verbose
    # One frozen snapshot of D's parameters per ICFG stage (cfg_T stages in total).
    self.d_params_list = [d_config(requires_grad=False)[1] for i in range(opt.cfg_T)]
    self.d_net, self.d_params = d_config(requires_grad=True)
    self.g_net, self.g_params = g_config(requires_grad=True)
    self.z_gen = z_gen
    self.cfg_eta = opt.cfg_eta
    self.optim_config = optim_config
    self.d_optimizer = None
    if optim_config is not None:
        self.d_optimizer = optim_config.create_optimizer(self.d_params)
    if from_file is not None:
        self.load(from_file)
    logging('---- D ----')
    if self.verbose:
        print_params(self.d_params)
    print_num_params(self.d_params)
    logging('---- G ----')
    if self.verbose:
        print_params(self.g_params)
    print_num_params(self.g_params)
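# A minimal construction sketch, not the repo's actual driver.  SimpleNamespace
# stands in for the parsed command-line options, and every concrete value below
# (cfg_T, cfg_eta, z_dim, ...) is an assumption for illustration only; the real
# `opt` carries more fields consumed elsewhere in this file.
from types import SimpleNamespace
import torch

def build_ddg_example(d_config, g_config, optim_config):
    opt = SimpleNamespace(cfg_T=15, cfg_eta=1.0, z_dim=64, verbose=False)
    z_gen = lambda num: torch.randn(num, opt.z_dim)  # prior over the latent z
    return DDG(opt, d_config, g_config, z_gen, optim_config)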
def icfg(self, loader, iter, d_loss, cfg_U):
    # One round of ICFG: train the T discriminators in sequence, cfg_U minibatch
    # updates each.  `iter` is the data iterator (not the builtin).
    timeLog('DDG::icfg ... ICFG with cfg_U=%d' % cfg_U)
    self.check_trainability()
    t_inc = 1 if self.verbose else 5  # progress-logging interval over t
    is_train = True
    for t in range(self.num_D()):
        sum_real = sum_fake = count = 0
        for upd in range(cfg_U):
            sample, iter = get_next(loader, iter)
            num = sample[0].size(0)
            fake = self.generate(num, t=t)  # fakes from the generator after t stages
            d_out_real = self.d_net(cast(sample[0]), self.d_params, is_train)
            d_out_fake = self.d_net(cast(fake), self.d_params, is_train)
            loss = d_loss(d_out_real, d_out_fake)
            loss.backward()
            self.d_optimizer.step()
            self.d_optimizer.zero_grad()
            with torch.no_grad():
                sum_real += float(d_out_real.sum())
                sum_fake += float(d_out_fake.sum())
                count += num
        self.store_d_params(t)  # freeze this stage's D parameters for generation
        if t_inc > 0 and ((t + 1) % t_inc == 0 or t == self.num_D() - 1):
            logging(' t=%d: real,%s, fake,%s ' % (t + 1, sum_real / count, sum_fake / count))
        raise_if_nan(sum_real)
        raise_if_nan(sum_fake)
    return iter, (sum_real - sum_fake) / count
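# A hedged sketch of a `d_loss` callable compatible with icfg() above.  The
# logistic loss is one standard choice for CFG-style training; whether it is
# the exact loss used elsewhere in this repo is an assumption.
import torch.nn.functional as F

def logistic_d_loss(d_out_real, d_out_fake):
    # Push D toward high scores on real data and low scores on fake data;
    # softplus(-x) = -log(sigmoid(x)), so both terms are logistic losses.
    return F.softplus(-d_out_real).mean() + F.softplus(d_out_fake).mean()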
def g_config(requires_grad):  # G
    if opt.g_model == DCGANx:
        return netdef.dcganx_G(opt.z_dim, opt.g_dim, opt.image_size, opt.channels,
                               opt.norm_type, requires_grad,
                               depth=opt.g_depth, do_bias=not opt.do_no_bias)
    elif opt.g_model == Resnet4:
        if opt.g_depth != 4:
            logging('WARNING: g_depth is ignored as g_model is Resnet4.')
        return netdef.resnet4_G(opt.z_dim, opt.g_dim, opt.image_size, opt.channels,
                                opt.norm_type, requires_grad,
                                do_bias=not opt.do_no_bias)
    elif opt.g_model == FCn:
        return netdef.fcn_G(opt.z_dim, opt.g_dim, opt.image_size, opt.channels,
                            requires_grad, depth=opt.g_depth)
    else:
        raise ValueError('g_model must be dcganx, resnet4, or fcn.')
def get_next(loader, iterator):
    # Return the next minibatch, restarting the iterator whenever the
    # underlying dataset is exhausted (i.e., no epoch boundaries).
    if iterator is None:
        iterator = iter(loader)
    try:
        data = next(iterator)
    except StopIteration:
        logging('get_next: ... getting to the end of data ... starting over ...')
        iterator = iter(loader)
        data = next(iterator)
    return data, iterator
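# Usage sketch for get_next(): drawing a fixed number of minibatches while
# letting the helper restart the data iterator transparently.  `loader` is
# assumed to be a torch.utils.data.DataLoader (anything re-iterable works).
def sample_batches(loader, num_steps):
    iterator = None
    for _ in range(num_steps):
        sample, iterator = get_next(loader, iterator)
        yield sample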
def _approximate(self, loader, g_loss):
    # Regress a fresh copy of G onto the ICFG-updated targets provided by
    # `loader` (pairs of latent z and target fake images).
    if self.verbose:
        timeLog('DDG::_approximate using %d data points ...' % len(loader.dataset))
    self.check_trainability()
    with torch.no_grad():
        g_params = clone_params(self.g_params, do_copy_requires_grad=True)
    optimizer = self.optim_config.create_optimizer(g_params)
    mtr_loss = tnt.meter.AverageValueMeter()
    last_loss_mean = 99999999  # effectively +infinity
    is_train = True
    for epoch in range(self.optim_config.cfg_x_epo):
        for sample in loader:
            z = cast(sample[0])
            target_fake = cast(sample[1])
            fake = self.g_net(z, g_params, is_train)
            loss = g_loss(fake, target_fake)
            mtr_loss.add(float(loss))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        loss_mean = mtr_loss.value()[0]
        if self.verbose:
            logging('%d ... %s ... ' % (epoch, str(loss_mean)))
        if loss_mean > last_loss_mean:
            self.optim_config.reduce_lr_(optimizer)  # decay lr when the loss goes up
        raise_if_nan(loss_mean)
        last_loss_mean = loss_mean
        mtr_loss.reset()
    copy_params(src=g_params, dst=self.g_params)
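# A hedged sketch of a `g_loss` callable for _approximate() above.  Plain
# squared error regresses the generator output onto the target fakes; treating
# this as the exact loss used by the repo is an assumption.
def squared_g_loss(fake, target_fake):
    # Mean squared error between G(z) and the ICFG-updated target images.
    return ((fake - target_fake) ** 2).mean()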