def resblock_g(h, y, scopename,
               n_classes, maps, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
               upsample=True, test=False, sn=True, coefs=[1.0]):
    """Residual block for the generator.

    Main path: CCBN -> ReLU -> (2x2 unpooling) -> conv, followed by
    CCBN -> ReLU -> conv. Shortcut path: (2x2 unpooling) -> 1x1 conv, where
    the 1x1 conv is only applied when the channel count changes or when
    upsampling is enabled. Both paths are summed with ``F.add2``.

    Args:
        h: Input feature map. Its shape must unpack into exactly four
            dimensions; the second one is compared against ``maps``.
        y: Conditioning input forwarded to ``CCBN`` (presumably class
            labels -- confirm against ``CCBN``).
        scopename: Parameter scope under which all parameters are created.
        n_classes: Forwarded to ``CCBN``.
        maps: Number of output maps of every convolution in the block.
        kernel, pad, stride: Settings of the two main-path convolutions.
        upsample: When true, both paths are unpooled with a (2, 2) kernel.
        test: Forwarded to ``CCBN`` and ``convolution``.
        sn: Forwarded to ``convolution``.
        coefs: Forwarded unchanged to ``CCBN``. NOTE(review): the default is
            a shared list object; it is not mutated in this function, but
            confirm that ``CCBN`` does not mutate it either.

    Returns:
        The result of ``F.add2(main, shortcut, True)``.
    """

    def _main_conv(inp):
        # Both main-path convolutions use identical settings.
        return convolution(inp, maps, kernel=kernel, pad=pad, stride=stride,
                           with_bias=True, sn=sn, test=test,
                           init_scale=np.sqrt(2))

    shortcut = h
    _, in_maps, _, _ = h.shape

    with nn.parameter_scope(scopename):
        # First stage: BN -> ReLU -> Upsample -> Conv
        with nn.parameter_scope("conv1"):
            h = F.relu(CCBN(h, y, n_classes, test=test, coefs=coefs),
                       inplace=True)
            if upsample:
                h = F.unpooling(h, kernel=(2, 2))
            h = _main_conv(h)

        # Second stage: BN -> ReLU -> Conv
        with nn.parameter_scope("conv2"):
            h = F.relu(CCBN(h, y, n_classes, test=test, coefs=coefs),
                       inplace=True)
            h = _main_conv(h)

        # Shortcut: Upsample -> 1x1 Conv (projection only when needed)
        if upsample:
            shortcut = F.unpooling(shortcut, kernel=(2, 2))
        needs_projection = upsample or in_maps != maps
        if needs_projection:
            with nn.parameter_scope("shortcut"):
                shortcut = convolution(shortcut, maps, kernel=(1, 1),
                                       pad=(0, 0), stride=(1, 1),
                                       with_bias=True, sn=sn, test=test)

    return F.add2(h, shortcut, True)
def backward_impl(self, inputs, outputs, prop_down, accum):
    """Back-propagate the gradient of ``outputs[0]`` into ``inputs[1].grad``.

    The gradient arriving at ``outputs[0]`` is passed through
    ``F.unpooling`` with the forward function's ``kernel`` argument and the
    result is either added to or copied into ``inputs[1].grad``. Nothing is
    written to ``inputs[0].grad``.

    Args:
        inputs: Sequence laid out as
            ``[inputs_fwd_graph] + [inputs_bwd_graph]`` or
            ``[inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]``.
            Only ``inputs[1].grad`` is used.
        outputs: Only ``outputs[0].grad`` is used.
        prop_down: Only index 1 is consulted; when it is false the call
            does nothing.
        accum: Only index 1 is consulted; true means accumulate into the
            existing gradient, false means overwrite it.

    Returns:
        None.
    """
    # NOTE(review): only ``kernel`` is forwarded to F.unpooling; if the
    # forward function also carries a ``channel_last`` argument it is not
    # honoured here -- confirm.
    kernel = self.forward_func.info.args["kernel"]

    if not prop_down[1]:
        return

    g_dy = inputs[1].grad
    g_dx0 = outputs[0].grad

    g_dy_ = F.unpooling(g_dx0, kernel)
    if accum[1]:
        # In-place add so the caller's gradient buffer is updated.
        g_dy += g_dy_
    else:
        g_dy.copy_from(g_dy_)
def upsample(x, maps, norm="ln", pad_mode="reflect", name="upsample"):
    """Upsample ``x`` with a (2, 2) unpooling followed by ``convblock``.

    Args:
        x: Input variable.
        maps: Forwarded to ``convblock`` as its second argument.
        norm: Forwarded to ``convblock``.
        pad_mode: Forwarded to ``convblock``.
        name: Parameter scope in which the block is built.

    Returns:
        The output of ``convblock``.
    """
    with nn.parameter_scope(name):
        # Alternative kept for reference:
        #   F.interpolate(x, (2, 2), mode="linear")
        enlarged = F.unpooling(x, (2, 2))
        # Positional 5, 2, 1 are passed straight to convblock
        # (presumably kernel / pad / stride -- confirm against convblock).
        return convblock(enlarged, maps, 5, 2, 1, norm=norm, pad_mode=pad_mode)
def cnn(self, h, resolution, channel, test):
    """CNN block

    The input is first upsampled once with a (2, 2) unpooling. Then the
    following operations are performed two times (scopes "conv1" and
    "conv2" under "phase_<resolution>"):

    1. Conv (3x3, pad 1, stride 1)
    2. BN followed by pixel-wise feature vector normalization
    3. Activation (``self.activation``)

    Args:
        h: Input variable.
        resolution: Used only to name the parameter scope.
        channel: Number of output maps of both convolutions.
        test: Forwarded to ``BN``.

    Returns:
        The output of the second activation.
    """
    h = F.unpooling(h, kernel=(2, 2))
    with nn.parameter_scope("phase_{}".format(resolution)):
        for stage in ("conv1", "conv2"):
            with nn.parameter_scope(stage):
                # Bias is only used when BN is disabled.
                h = conv(h, channel, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                         with_bias=not self.use_bn,
                         use_wscale=self.use_wscale,
                         use_he_backward=self.use_he_backward)
                normed = BN(h, use_bn=self.use_bn, test=test)
                h = pixel_wise_feature_vector_normalization(normed)
                h = self.activation(h)
    return h
def unpool_block(x, n=0, k=(4, 4), s=(2, 2), p=(1, 1), leaky=False, unpool=False, init_method=None):
    """Upsampling block: (deconv | unpooling + conv) -> instance norm -> activation.

    Args:
        x: Input variable.
        n: Number of output maps of the (de)convolution.
        k, s, p: Kernel, stride and pad of the deconvolution; ignored when
            ``unpool`` is true (the unpooling path uses a fixed 3x3 conv).
        leaky: Use ``F.leaky_relu`` with ``alpha=0.2`` instead of ``F.relu``.
        unpool: Select the unpooling + convolution path instead of the
            deconvolution path.
        init_method: Forwarded to the (de)convolution.

    Returns:
        The activated output variable.
    """
    if unpool:
        logger.info("Unpooling was used.")
        x = F.unpooling(x, kernel=(2, 2))
        x = convolution(x, n, kernel=(3, 3), stride=(1, 1),
                        pad=(1, 1), init_method=init_method)
    else:
        logger.info("Deconvolution was used.")
        x = deconvolution(x, n=n, kernel=k, stride=s,
                          pad=p, init_method=init_method)

    x = instance_normalization(x, fix_parameters=True)

    if leaky:
        return F.leaky_relu(x, alpha=0.2)
    return F.relu(x)
def upsampling_block(x, i):
    """Build the ``i``-th upsampling block.

    Applies ``af`` to ``x``, unpools it with kernel ``(2,)``, merges it with
    ``ds_outputs[-i-1]`` via ``crop_and_concat``, then applies a convolution
    followed by ``af``.

    Args:
        x: Input variable.
        i: Block index; selects the entry of ``ds_outputs`` to merge with
            and determines the number of output maps.

    Returns:
        The activated convolution output.
    """
    # NOTE: '%2d' pads with a space, not a zero; the resulting scope name is
    # kept as is because parameter names depend on it.
    scope_name = ('us_block-%2d' % i)
    with nn.parameter_scope(scope_name):
        upsampled = F.unpooling(af(x), (2,))
        merged = crop_and_concat(ds_outputs[-i - 1], upsampled)
        out_maps = num_initial_filters + num_initial_filters * (num_layers - i - 1)
        conv_out = conv(merged, out_maps, (merge_filter_size,), (2,), name='conv')
        return af(conv_out)
def unpooling_data_grad_backward(inputs, kernel, channel_last=False):
    """Backward of the unpooling data-grad function.

    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward
        function. Only the first element is used.
      kernel: Kernel passed to ``F.unpooling``.
      channel_last (bool): Passed to ``F.unpooling``.

    Return:
      nn.Variable: ``F.unpooling`` applied to ``inputs[0]``.
    """
    return F.unpooling(inputs[0], kernel, channel_last)
def upsample(h, maps, up, test=False, name="convblock"):
    """Upsample ``h`` by a factor of two, then apply batch norm and ReLU.

    Args:
        h: Input variable.
        maps: Number of output maps of the convolution. The "deconv" mode
            uses ``maps * 2`` output maps for its deconvolution instead.
        up: One of "nearest", "linear", "unpooling" or "deconv".
        test: Batch normalization uses batch statistics when this is false.
        name: Name given to the (de)convolution and the batch normalization.

    Returns:
        The activated output variable.

    Raises:
        ValueError: If ``up`` is not one of the supported modes.
    """
    if up == "deconv":
        h = PF.deconvolution(h, maps * 2, (2, 2), (0, 0), (2, 2), name=name)
    elif up in ("nearest", "linear", "unpooling"):
        # These three modes share the same 3x3 convolution and differ only
        # in how the result is enlarged afterwards.
        h = PF.convolution(h, maps, (3, 3), (1, 1), name=name)
        if up == "unpooling":
            h = F.unpooling(h, (2, 2))
        else:
            mode = "nearest" if up == "nearest" else "linear"
            h = F.interpolate(h, scale=(2, 2), mode=mode)
    else:
        raise ValueError(
            'Set "up" option in ["nearest", "linear", "unpooling", "deconv"]')

    h = PF.batch_normalization(h, batch_stat=not test, name=name)
    return F.relu(h)
def transition_cnn(self, h, pre_resolution, nxt_resolution, pre_channel, nxt_channel, alpha, test):
    """Blend the previous-resolution and next-resolution RGB outputs.

    Args:
        h: Input feature variable.
        pre_resolution: Passed to ``to_RGB`` for the unpooled input.
        nxt_resolution: Passed to ``cnn`` and to ``to_RGB`` for its output.
        pre_channel: Not used by this method.
        nxt_channel: Passed to ``cnn``.
        alpha: Blend weight; the result is
            ``(1 - alpha) * previous + alpha * next``.
        test: Passed to ``cnn``.

    Returns:
        The blended variable.
    """
    # Previous branch: plain 2x2 unpooling of the input, then to-RGB.
    enlarged = F.unpooling(h, kernel=(2, 2))
    previous_rgb = self.to_RGB(enlarged, pre_resolution)

    # Next branch: the new CNN block, then to-RGB.
    features = self.cnn(h, nxt_resolution, nxt_channel, test)
    next_rgb = self.to_RGB(features, nxt_resolution)

    return (1 - alpha) * previous_rgb + alpha * next_rgb
def _transition(self, ecpoch_per_resolution):
    """Train generator and discriminator through their ``transition`` paths.

    For every epoch the data iterator is consumed until its epoch counter
    advances. Each iteration builds the discriminator loss from real
    (average-pooled) and generated images, steps the discriminator when
    ``itr % n_critic + 1 == n_critic``, then builds the generator loss and
    steps the generator. ``alpha`` is recomputed as
    ``global_itr / total_itr`` after every iteration and handed to
    ``gen.transition`` / ``dis.transition``.

    NOTE(review): no explicit ``forward()`` is called before ``backward()``;
    presumably the caller runs this under nnabla's auto-forward mode --
    confirm.

    Args:
        ecpoch_per_resolution (int): Number of epochs to run. (The spelling
            is part of the existing interface.)
    """
    batch_size = self.di.batch_size
    resolution = self.gen.resolution_list[-1]
    phase = "{}to{}".format(
        self.gen.resolution_list[-2], self.gen.resolution_list[-1])
    logger.info("phase : {}".format(phase))

    # Ratio between the trainer's last resolution and the generator's
    # current one; used to pool real images down and to unpool generated
    # images up for monitoring.
    kernel_size = self.resolution_list[-1] // resolution
    kernel = (kernel_size, kernel_size)

    total_itr = (self.di.size // batch_size + 1) * ecpoch_per_resolution
    global_itr = 1.
    alpha = global_itr / total_itr

    for epoch in range(ecpoch_per_resolution):
        logger.info("epoch : {}".format(epoch + 1))
        itr = 0
        current_epoch = self.di.epoch
        # Loop until the data iterator reports a new epoch.
        while self.di.epoch == current_epoch:
            img, _ = self.di.next()
            x = nn.Variable.from_numpy_array(img)

            # Generated batch for the discriminator loss.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen.transition(z, alpha, test=True)

            # NOTE(review): the return value of unlinked() is discarded, so
            # this call presumably has no effect on ``y`` -- confirm against
            # the nnabla Variable API.
            y.unlinked()
            y.need_grad = False
            x_r = F.average_pooling(x, kernel=kernel)

            p_real = self.dis.transition(x_r, alpha)
            p_fake = self.dis.transition(y, alpha)

            # mean((p_real - 1)^2 + l2_fake_weight * p_fake^2)
            loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                              + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)

            # Discriminator is stepped only on every n_critic-th iteration.
            if itr % self.n_critic + 1 == self.n_critic:
                with nn.parameter_scope("discriminator"):
                    self.solver_dis.set_parameters(nn.get_parameters(),
                                                   reset=False, retain_state=True)
                    self.solver_dis.zero_grad()
                    loss_dis.backward(clear_buffer=True)
                    self.solver_dis.update()

            # Fresh generated batch for the generator loss.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen.transition(z, alpha, test=False)
            p_fake = self.dis.transition(y, alpha)

            # mean((p_fake - 1)^2)
            loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2))
            with nn.parameter_scope("generator"):
                self.solver_gen.set_parameters(
                    nn.get_parameters(), reset=False, retain_state=True)
                self.solver_gen.zero_grad()
                loss_gen.backward(clear_buffer=True)
                self.solver_gen.update()

            itr += 1
            global_itr += 1.
            alpha = global_itr / total_itr

        # Periodically generate images from the fixed test latent and log
        # them, unpooled by ``kernel``.
        if epoch % self.save_image_interval + 1 == self.save_image_interval:
            z = nn.Variable.from_numpy_array(self.z_test)
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen.transition(z, alpha)
            img_name = "phase_{}_epoch_{}".format(phase, epoch + 1)
            self.monitor_image_tile.add(
                img_name, F.unpooling(y, kernel=kernel))
def _train(self, ecpoch_per_resolution, each_save=False):
    """Train generator and discriminator at the generator's current resolution.

    For every epoch the data iterator is consumed until its epoch counter
    advances. Each iteration builds the discriminator loss from real
    (average-pooled) and generated images, steps the discriminator when
    ``itr % n_critic + 1 == n_critic``, then builds the generator loss,
    steps the generator and records monitor values.

    NOTE(review): no explicit ``forward()`` is called before ``backward()``;
    presumably the caller runs this under nnabla's auto-forward mode --
    confirm.

    Args:
        ecpoch_per_resolution (int): Number of epochs to run. (The spelling
            is part of the existing interface.)
        each_save (bool): When true, generator and discriminator parameters
            are saved after every epoch.
    """
    batch_size = self.di.batch_size
    resolution = self.gen.resolution_list[-1]
    logger.info("phase : {}".format(resolution))

    # Ratio between the trainer's last resolution and the generator's
    # current one; used to pool real images down and to unpool generated
    # images up for monitoring.
    kernel_size = self.resolution_list[-1] // resolution
    kernel = (kernel_size, kernel_size)

    # Log one batch of original images for reference. This consumes a batch
    # from the data iterator before training starts.
    img_name = "original_phase_{}".format(resolution)
    img, _ = self.di.next()
    self.monitor_image_tile.add(img_name, img)

    for epoch in range(ecpoch_per_resolution):
        logger.info("epoch : {}".format(epoch + 1))
        itr = 0
        current_epoch = self.di.epoch
        # Loop until the data iterator reports a new epoch.
        while self.di.epoch == current_epoch:
            img, _ = self.di.next()
            x = nn.Variable.from_numpy_array(img)

            # Generated batch for the discriminator loss.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen(z, test=True)

            # NOTE(review): the return value of unlinked() is discarded, so
            # this call presumably has no effect on ``y`` -- confirm against
            # the nnabla Variable API.
            y.unlinked()
            y.need_grad = False
            x_r = F.average_pooling(x, kernel=kernel)

            p_real = self.dis(x_r)
            p_fake = self.dis(y)
            # Marked persistent because their data is read for monitoring
            # after backward(clear_buffer=True).
            p_real.persistent, p_fake.persistent = True, True

            # mean((p_real - 1)^2 + l2_fake_weight * p_fake^2)
            loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                              + F.pow_scalar(p_fake, 2.)
                              * self.l2_fake_weight)
            loss_dis.persistent = True

            # Discriminator is stepped only on every n_critic-th iteration.
            if itr % self.n_critic + 1 == self.n_critic:
                with nn.parameter_scope("discriminator"):
                    self.solver_dis.set_parameters(nn.get_parameters(),
                                                   reset=False, retain_state=True)
                    self.solver_dis.zero_grad()
                    loss_dis.backward(clear_buffer=True)
                    self.solver_dis.update()

            # Fresh generated batch for the generator loss.
            z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen(z, test=False)
            p_fake = self.dis(y)
            p_fake.persistent = True

            # mean((p_fake - 1)^2)
            loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2.))
            loss_gen.persistent = True
            with nn.parameter_scope("generator"):
                self.solver_gen.set_parameters(nn.get_parameters(),
                                               reset=False, retain_state=True)
                self.solver_gen.zero_grad()
                loss_gen.backward(clear_buffer=True)
                self.solver_gen.update()

            # Monitor
            # ``p_fake`` here is the one from the generator step, not the
            # one used in the discriminator loss.
            self.monitor_p_real.add(
                self.global_itr, p_real.d.copy().mean())
            self.monitor_p_fake.add(
                self.global_itr, p_fake.d.copy().mean())
            self.monitor_loss_dis.add(self.global_itr, loss_dis.d.copy())
            self.monitor_loss_gen.add(self.global_itr, loss_gen.d.copy())

            itr += 1
            self.global_itr += 1

        # Periodically generate images from the fixed test latent and log
        # them, unpooled by ``kernel``.
        if epoch % self.save_image_interval + 1 == self.save_image_interval:
            z = nn.Variable.from_numpy_array(self.z_test)
            z = pixel_wise_feature_vector_normalization(
                z) if self.hyper_sphere else z
            y = self.gen(z, test=True)
            img_name = "phase_{}_epoch_{}".format(resolution, epoch + 1)
            self.monitor_image_tile.add(
                img_name, F.unpooling(y, kernel=kernel))

        # Optional per-epoch checkpoint of both networks.
        if each_save:
            self.gen.save_parameters(self.monitor_path, "Gen_phase_{}_epoch_{}".format(
                self.resolution_list[-1], epoch+1))
            self.dis.save_parameters(self.monitor_path, "Dis_phase_{}_epoch_{}".format(
                self.resolution_list[-1], epoch+1))