def _create_optimizer(ctx, o, networks, datasets):
    class Optimizer:
        pass

    optimizer = Optimizer()

    optimizer.comm = current_communicator()
    comm_size = optimizer.comm.size if optimizer.comm else 1
    optimizer.start_iter = (o.start_iter - 1) // comm_size + \
        1 if o.start_iter > 0 else 0
    optimizer.end_iter = (o.end_iter - 1) // comm_size + \
        1 if o.end_iter > 0 else 0
    optimizer.name = o.name
    optimizer.order = o.order
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]

    # bind data iterators and variable assignments from the protobuf config
    optimizer.data_iterators = OrderedDict()
    for d in o.dataset_name:
        optimizer.data_iterators[d] = datasets[d].data_iterator

    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[optimizer.network.variables[
            d.variable_name]] = d.data_name

    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[optimizer.network.variables[
            g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = []
    for l in o.loss_variable:
        optimizer.loss_variables.append(
            optimizer.network.variables[l.variable_name])

    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        param_variable_names = _get_matching_variable_names(
            p.variable_name, optimizer.network.variables.keys())
        for v_name in param_variable_names:
            optimizer.parameter_learning_rate_multipliers[
                optimizer.network.
                variables[v_name]] = p.learning_rate_multiplier

    # instantiate the solver requested by the config
    with nn.context_scope(ctx):
        if o.solver.type == 'Adagrad':
            optimizer.solver = S.Adagrad(o.solver.adagrad_param.lr,
                                         o.solver.adagrad_param.eps)
            init_lr = o.solver.adagrad_param.lr
        elif o.solver.type == 'Adadelta':
            optimizer.solver = S.Adadelta(o.solver.adadelta_param.lr,
                                          o.solver.adadelta_param.decay,
                                          o.solver.adadelta_param.eps)
            init_lr = o.solver.adadelta_param.lr
        elif o.solver.type == 'Adam':
            optimizer.solver = S.Adam(o.solver.adam_param.alpha,
                                      o.solver.adam_param.beta1,
                                      o.solver.adam_param.beta2,
                                      o.solver.adam_param.eps)
            init_lr = o.solver.adam_param.alpha
        elif o.solver.type == 'Adamax':
            optimizer.solver = S.Adamax(o.solver.adamax_param.alpha,
                                        o.solver.adamax_param.beta1,
                                        o.solver.adamax_param.beta2,
                                        o.solver.adamax_param.eps)
            init_lr = o.solver.adamax_param.alpha
        elif o.solver.type == 'AdaBound':
            optimizer.solver = S.AdaBound(o.solver.adabound_param.alpha,
                                          o.solver.adabound_param.beta1,
                                          o.solver.adabound_param.beta2,
                                          o.solver.adabound_param.eps,
                                          o.solver.adabound_param.final_lr,
                                          o.solver.adabound_param.gamma)
            init_lr = o.solver.adabound_param.alpha
        elif o.solver.type == 'AMSGRAD':
            optimizer.solver = S.AMSGRAD(o.solver.amsgrad_param.alpha,
                                         o.solver.amsgrad_param.beta1,
                                         o.solver.amsgrad_param.beta2,
                                         o.solver.amsgrad_param.eps)
            init_lr = o.solver.amsgrad_param.alpha
        elif o.solver.type == 'AMSBound':
            optimizer.solver = S.AMSBound(o.solver.amsbound_param.alpha,
                                          o.solver.amsbound_param.beta1,
                                          o.solver.amsbound_param.beta2,
                                          o.solver.amsbound_param.eps,
                                          o.solver.amsbound_param.final_lr,
                                          o.solver.amsbound_param.gamma)
            init_lr = o.solver.amsbound_param.alpha
        elif o.solver.type == 'Eve':
            p = o.solver.eve_param
            optimizer.solver = S.Eve(p.alpha, p.beta1, p.beta2,
                                     p.beta3, p.k, p.k2, p.eps)
            init_lr = p.alpha
        elif o.solver.type == 'Momentum':
            optimizer.solver = S.Momentum(o.solver.momentum_param.lr,
                                          o.solver.momentum_param.momentum)
            init_lr = o.solver.momentum_param.lr
        elif o.solver.type == 'Nesterov':
            optimizer.solver = S.Nesterov(o.solver.nesterov_param.lr,
                                          o.solver.nesterov_param.momentum)
            init_lr = o.solver.nesterov_param.lr
        elif o.solver.type == 'RMSprop':
            optimizer.solver = S.RMSprop(o.solver.rmsprop_param.lr,
                                         o.solver.rmsprop_param.decay,
                                         o.solver.rmsprop_param.eps)
            init_lr = o.solver.rmsprop_param.lr
        elif o.solver.type == 'Sgd' or o.solver.type == 'SGD':
            optimizer.solver = S.Sgd(o.solver.sgd_param.lr)
            init_lr = o.solver.sgd_param.lr
        else:
            raise ValueError('Solver "' + o.solver.type +
                             '" is not supported.')

    # register only parameters with a positive learning rate multiplier
    parameters = {
        v.name: v.variable_instance
        for v, local_lr in
        optimizer.parameter_learning_rate_multipliers.items()
        if local_lr > 0.0
    }
    optimizer.solver.set_parameters(parameters)
    optimizer.parameters = OrderedDict(
        sorted(parameters.items(), key=lambda x: x[0]))

    optimizer.weight_decay = o.solver.weight_decay

    # keep following 2 lines for backward compatibility
    optimizer.lr_decay = o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1
    optimizer.solver.set_states_from_protobuf(o)

    optimizer.comm = current_communicator()
    comm_size = optimizer.comm.size if optimizer.comm else 1

    # build the learning rate scheduler (constant by default)
    optimizer.scheduler = ExponentialScheduler(init_lr, 1.0, 1)

    if o.solver.lr_scheduler_type == 'Polynomial':
        if o.solver.polynomial_scheduler_param.power != 0.0:
            optimizer.scheduler = PolynomialScheduler(
                init_lr,
                o.solver.polynomial_scheduler_param.max_iter // comm_size,
                o.solver.polynomial_scheduler_param.power)
    elif o.solver.lr_scheduler_type == 'Cosine':
        optimizer.scheduler = CosineScheduler(
            init_lr, o.solver.cosine_scheduler_param.max_iter // comm_size)
    elif o.solver.lr_scheduler_type == 'Exponential':
        if o.solver.exponential_scheduler_param.gamma != 1.0:
            optimizer.scheduler = ExponentialScheduler(
                init_lr,
                o.solver.exponential_scheduler_param.gamma,
                o.solver.exponential_scheduler_param.iter_interval // comm_size
                if o.solver.exponential_scheduler_param.iter_interval > comm_size else 1)
    elif o.solver.lr_scheduler_type == 'Step':
        if o.solver.step_scheduler_param.gamma != 1.0 and len(
                o.solver.step_scheduler_param.iter_steps) > 0:
            optimizer.scheduler = StepScheduler(
                init_lr,
                o.solver.step_scheduler_param.gamma,
                [step // comm_size
                 for step in o.solver.step_scheduler_param.iter_steps])
    elif o.solver.lr_scheduler_type == 'Custom':
        # ToDo
        raise NotImplementedError()
    elif o.solver.lr_scheduler_type == '':
        if o.solver.lr_decay_interval != 0 or o.solver.lr_decay != 0.0:
            optimizer.scheduler = ExponentialScheduler(
                init_lr,
                o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0,
                o.solver.lr_decay_interval // comm_size
                if o.solver.lr_decay_interval > comm_size else 1)
    else:
        raise ValueError('Learning Rate Scheduler "' +
                         o.solver.lr_scheduler_type + '" is not supported.')

    # optional linear warmup wraps the scheduler selected above
    if o.solver.lr_warmup_scheduler_type == 'Linear':
        if o.solver.linear_warmup_scheduler_param.warmup_iter >= comm_size:
            optimizer.scheduler = LinearWarmupScheduler(
                optimizer.scheduler,
                o.solver.linear_warmup_scheduler_param.warmup_iter // comm_size)

    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables,
        optimizer.parameter_learning_rate_multipliers)

    return optimizer
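# A minimal usage sketch (not part of the library code above): it shows roughly
# how the object returned by _create_optimizer could drive one parameter update.
# Data feeding is elided; `feed_minibatch` is a hypothetical helper, and running
# the graph directly through the loss variables' `variable_instance` is an
# assumption about how the network graph is built, not the library's own
# training loop.
def _update_once(optimizer, iteration):
    # schedule the learning rate for this iteration
    lr = optimizer.scheduler.get_learning_rate(iteration)
    optimizer.solver.set_learning_rate(lr)

    # feed_minibatch(optimizer)  # hypothetical: copy data iterator output into
    #                            # the variables in optimizer.dataset_assign

    optimizer.solver.zero_grad()
    for loss in optimizer.loss_variables:
        # forward and backward over the graph ending at each loss variable
        loss.variable_instance.forward(clear_no_need_grad=True)
        loss.variable_instance.backward()

    # apply weight decay and the gradient step
    if optimizer.weight_decay > 0.0:
        optimizer.solver.weight_decay(optimizer.weight_decay)
    optimizer.solver.update()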
class Model:
    def __init__(self, real, num_layer, fs, min_fs, kernel, pad, lam_grad,
                 alpha_recon, d_lr, g_lr, beta1, gamma, lr_milestone, scope,
                 test=False):
        self.real = real
        self.num_layer = num_layer
        self.fs = fs
        self.min_fs = min_fs
        self.kernel = kernel
        self.pad = pad
        self.lam_grad = lam_grad
        self.alpha_recon = alpha_recon
        self.d_lr = d_lr
        self.g_lr = g_lr
        self.beta1 = beta1
        self.gamma = gamma
        self.lr_milestone = lr_milestone
        self.scope = scope
        self.test = test
        self._build()
        self._setup_solvers()

    # for generating larger images than the ones in training time
    def get_generator_func(self, image_shape, channel_first=True):
        if not channel_first:
            image_shape = (image_shape[2], image_shape[0], image_shape[1])

        generator_fn, _ = self._network_funcs()

        # build inference graph
        x_ = nn.Variable((1,) + image_shape)
        y_ = nn.Variable((1,) + image_shape)
        fake = generator_fn(x=x_, y=y_)

        def func(x, y):
            x_.d = x
            y_.d = y
            fake.forward(clear_buffer=True)
            return fake.d.copy()

        return func

    def _network_funcs(self):
        # generator model
        generator_fn = partial(generator,
                               num_layer=self.num_layer,
                               fs=self.fs,
                               min_fs=self.min_fs,
                               kernel=self.kernel,
                               pad=self.pad,
                               scope='%s/generator' % self.scope,
                               test=self.test)
        # discriminator model
        discriminator_fn = partial(discriminator,
                                   num_layer=self.num_layer,
                                   fs=self.fs,
                                   min_fs=self.min_fs,
                                   kernel=self.kernel,
                                   pad=self.pad,
                                   scope='%s/discriminator' % self.scope,
                                   test=self.test)
        return generator_fn, discriminator_fn

    def _build(self):
        generator_fn, discriminator_fn = self._network_funcs()

        # real shape
        ch, w, h = self.real.shape[1:]

        # inputs
        self.x = nn.Variable((1, ch, w, h))
        self.y = nn.Variable((1, ch, w, h))
        self.rec_x = nn.Variable((1, ch, w, h))
        self.rec_y = nn.Variable((1, ch, w, h))
        y_real = nn.Variable.from_numpy_array(self.real)
        y_real.persistent = True

        # padding inputs
        padded_x = _pad(self.x, self.kernel, self.num_layer)
        padded_rec_x = _pad(self.rec_x, self.kernel, self.num_layer)

        # generate fake image
        self.fake = generator_fn(x=padded_x, y=self.y)
        fake_without_grads = F.identity(self.fake)
        fake_without_grads.need_grad = False
        rec = generator_fn(x=padded_rec_x, y=self.rec_y)

        # discriminate images
        p_real = discriminator_fn(x=y_real)
        p_fake = discriminator_fn(x=self.fake)
        p_fake_without_grads = discriminator_fn(x=fake_without_grads)

        # gradient penalty for discriminator
        grad_penalty = _calc_gradient_penalty(y_real, fake_without_grads,
                                              discriminator_fn)

        # discriminator loss
        self.d_real_error = -F.mean(p_real)
        self.d_fake_error = F.mean(p_fake_without_grads)
        self.d_error = self.d_real_error + self.d_fake_error \
            + self.lam_grad * grad_penalty

        # generator loss
        self.rec_error = F.mean(F.squared_error(rec, y_real))
        self.g_fake_error = -F.mean(p_fake)
        self.g_error = self.g_fake_error + self.alpha_recon * self.rec_error

    def _setup_solvers(self):
        # prepare training parameters
        with nn.parameter_scope('%s/discriminator' % self.scope):
            d_params = nn.get_parameters()
        with nn.parameter_scope('%s/generator' % self.scope):
            g_params = nn.get_parameters()

        # create solver for discriminator
        self.d_lr_scheduler = StepScheduler(self.d_lr, self.gamma,
                                            [self.lr_milestone])
        self.d_solver = S.Adam(self.d_lr, beta1=self.beta1, beta2=0.999)
        self.d_solver.set_parameters(d_params)

        # create solver for generator
        self.g_lr_scheduler = StepScheduler(self.g_lr, self.gamma,
                                            [self.lr_milestone])
        self.g_solver = S.Adam(self.g_lr, beta1=self.beta1, beta2=0.999)
        self.g_solver.set_parameters(g_params)

    def generate(self, x, y):
        self.x.d = x
        self.y.d = y
        self.fake.forward(clear_buffer=True)
        return self.fake.d

    def update_g(self, epoch, x, y, rec_x, rec_y):
        self.x.d = x
        self.y.d = y
        self.rec_x.d = rec_x
        self.rec_y.d = rec_y
        self.g_error.forward()
        fake_error = self.g_fake_error.d.copy()
        rec_error = self.rec_error.d.copy()
        self.g_solver.zero_grad()
        self.g_error.backward(clear_buffer=True)
        lr = self.g_lr_scheduler.get_learning_rate(epoch)
        self.g_solver.set_learning_rate(lr)
        self.g_solver.update()
        return fake_error, rec_error

    def update_d(self, epoch, x, y):
        self.x.d = x
        self.y.d = y
        self.d_error.forward()
        real_error = self.d_real_error.d.copy()
        fake_error = self.d_fake_error.d.copy()
        self.d_solver.zero_grad()
        self.d_error.backward(clear_buffer=True)
        lr = self.d_lr_scheduler.get_learning_rate(epoch)
        self.d_solver.set_learning_rate(lr)
        self.d_solver.update()
        return fake_error, real_error
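# A minimal single-scale training sketch (not the original training script): the
# hyperparameter values, image size, and the noise/previous-scale inputs below
# are illustrative assumptions; the real schedule (multi-scale stages, noise
# amplitude, number of D/G steps per epoch) lives in the surrounding code.
import numpy as np

real = np.random.rand(1, 3, 128, 128).astype(np.float32)  # stand-in for one real image
model = Model(real, num_layer=5, fs=32, min_fs=32, kernel=3, pad=0,
              lam_grad=0.1, alpha_recon=10.0, d_lr=5e-4, g_lr=5e-4,
              beta1=0.5, gamma=0.1, lr_milestone=1600, scope='s0')

for epoch in range(2000):
    noise = np.random.randn(*real.shape).astype(np.float32)  # input noise map
    prev = np.zeros_like(real)       # previous-scale output (zeros at the coarsest scale)
    rec_noise = np.zeros_like(real)  # fixed reconstruction noise for this scale
    d_fake_err, d_real_err = model.update_d(epoch, x=noise, y=prev)
    g_fake_err, rec_err = model.update_g(epoch, x=noise, y=prev,
                                         rec_x=rec_noise, rec_y=prev)

fake = model.generate(x=noise, y=prev)  # sample an image at this scale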