def forward(self, z, y):
    y_onehot = one_hot_embedding(y, self.num_classes)
    z_in = torch.cat([z, y_onehot], dim=1)
    output = self.fc(z_in)
    output = output.view(-1, self.z_dim, 1, 1)
    output = self.TS(output)
    output = pixel_norm(output)
    output = self.deconv1(output)
    output = self.BN_1(output)
    output = self.TS(output)
    output = pixel_norm(output)
    output = self.deconv2(output)
    output = self.BN_2(output)
    output = self.TS(output)
    output = pixel_norm(output)
    output = self.deconv3(output)
    output = self.BN_3(output)
    output = self.TS(output)
    output = pixel_norm(output)
    output = self.deconv4(output)
    output = self.outact(output)
    return output.view(-1, 32 * 32)

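# The PyTorch forward() snippets in this collection call pixel_norm() without
# defining it. A minimal sketch, assuming the standard ProGAN-style pixelwise
# feature normalization (each spatial position divided by the RMS over its
# channels); the epsilon value and channel-axis convention are assumptions.
import torch

def pixel_norm(x, eps=1e-8):
    # x: [N, C, H, W] (or [N, C]); normalize each position over its channel vector
    return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + eps)
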
def _residule_block(x, dim, name):
    with tf.compat.v1.variable_scope(name):
        y = conv(x, dim, 3, 1, "conv1")
        y = lrelu(y)
        y = pixel_norm(y)
        y = conv(y, dim, 3, 1, "conv2")
        y = pixel_norm(y)
        return y + x

def gblock(name, inputs, filters, data_format):
    with tf.variable_scope(name):
        x = ops.conv2d_up('conv_up', inputs, filters, 3, data_format)
        x = ops.leaky_relu(x)
        x = ops.pixel_norm(x, data_format)
        x = ops.conv2d('conv', x, filters, 3, data_format)
        x = ops.leaky_relu(x)
        x = ops.pixel_norm(x, data_format)
        return x

def forward(self, x):
    h = pixel_norm(x)
    if self.upsample:
        h = self.upsample(h)
        x = self.upsample(x)
    h = self.conv1(h)
    h = self.activation(pixel_norm(h))
    h = self.conv2(h)
    if self.learnable_sc:
        x = self.conv_sc(x)
    return h + x

def generator(x, last_layer_resolution, cfg, is_training=True, scope='Generator'):
    def rname(resolution):
        return str(resolution) + 'x' + str(resolution)

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        with tf.variable_scope("4x4"):
            fn4 = cfg.resolution_to_filt_num[4]
            x = ops.pixel_norm(x, cfg.data_format)
            x = ops.dense('dense', x, 4 * 4 * fn4, cfg.data_format)
            if cfg.data_format == 'NHWC':
                x = tf.reshape(x, [-1, 4, 4, fn4])
            else:
                x = tf.reshape(x, [-1, fn4, 4, 4])
            x = ops.leaky_relu(x)
            x = ops.pixel_norm(x, cfg.data_format)
            x = ops.conv2d('conv', x, fn4, 3, cfg.data_format)
            x = ops.leaky_relu(x)
            x = ops.pixel_norm(x, cfg.data_format)

        resolution = 8
        prev_x = None
        while resolution <= last_layer_resolution:
            filt_num = cfg.resolution_to_filt_num[resolution]
            prev_x = x
            x = gblock(rname(resolution), x, filt_num, cfg.data_format)
            resolution *= 2
        resolution = resolution // 2

        if resolution > cfg.starting_resolution:
            t = tf.get_variable(
                rname(resolution) + '_t',
                shape=[],
                collections=[tf.GraphKeys.GLOBAL_VARIABLES, "lerp"],
                dtype=tf.float32,
                initializer=tf.zeros_initializer(),
                trainable=False)
            x1 = ops.to_rgb('to_rgb_' + rname(resolution // 2), prev_x, cfg.data_format)
            x1 = ops.upscale2d(x1, cfg.data_format)
            x2 = ops.to_rgb('to_rgb_' + rname(resolution), x, cfg.data_format)
            x = ops.lerp_clip(x1, x2, t)
        else:
            x = ops.to_rgb('to_rgb_' + rname(resolution), x, cfg.data_format)

        x_shape = utils.int_shape(x)
        assert (resolution == x_shape[1 if cfg.data_format == 'NHWC' else 3])
        assert (resolution == x_shape[2])
        return x

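# A hypothetical call site for the progressive generator above. The cfg field
# names mirror what the function reads; the concrete values, latent size and
# SimpleNamespace container are illustrative assumptions only.
import tensorflow as tf
from types import SimpleNamespace

cfg = SimpleNamespace(
    data_format='NHWC',
    starting_resolution=4,
    resolution_to_filt_num={4: 512, 8: 512, 16: 256, 32: 128},
)
z = tf.random.normal([8, 512])  # assumed latent batch
fake_images = generator(z, last_layer_resolution=32, cfg=cfg)
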
def forward(self, z, y):
    y_onehot = one_hot_embedding(y, self.num_classes)
    z_in = torch.cat([z, y_onehot], dim=1)
    output = self.fc(z_in)
    output = output.view(-1, 4 * self.model_dim, 4, 4)
    output = self.relu(output)
    output = pixel_norm(output)
    output = self.deconv1(output)
    output = output[:, :, :7, :7]
    output = self.relu(output)
    output = pixel_norm(output)
    output = self.deconv2(output)
    output = self.relu(output).contiguous()
    output = pixel_norm(output)
    output = self.deconv3(output)
    output = self.outact(output)
    return output.view(-1, IMG_W * IMG_H)

def forward(self, z, y):
    y_onehot = one_hot_embedding(y, self.num_classes)
    z_in = torch.cat([z, y_onehot], dim=1)
    output = self.fc(z_in)
    output = output.view(-1, 4 * self.model_dim, 4, 4)
    output = self.relu(output)
    output = pixel_norm(output)
    output = self.block1(output)
    output = self.block2(output)
    output = self.block3(output)
    output = self.outact(self.output(output))
    output = output[:, :, :-2, :-2]
    output = torch.reshape(output, [-1, IMG_H * IMG_W])
    return output

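# one_hot_embedding() is assumed by the conditional forward() methods above but
# is not defined in this collection. A minimal sketch that maps integer class
# labels to one-hot float vectors; the exact original implementation may differ.
import torch

def one_hot_embedding(labels, num_classes):
    # labels: [N] int64 tensor of class ids -> [N, num_classes] float tensor
    return torch.nn.functional.one_hot(labels, num_classes).float()
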
def test_pixel_norm_output(self):
    test_input = tf.constant([1., 2., 3., 4., 5., 6., 7., 8.], shape=[1, 2, 2, 2])
    output = pixel_norm(test_input)
    # Each pixel (its pair of channel values) is divided by the root mean
    # square over the channel axis.
    target = [
        1. / np.sqrt((1.**2 + 2.**2) / 2), 2. / np.sqrt((1.**2 + 2.**2) / 2),
        3. / np.sqrt((3.**2 + 4.**2) / 2), 4. / np.sqrt((3.**2 + 4.**2) / 2),
        5. / np.sqrt((5.**2 + 6.**2) / 2), 6. / np.sqrt((5.**2 + 6.**2) / 2),
        7. / np.sqrt((7.**2 + 8.**2) / 2), 8. / np.sqrt((7.**2 + 8.**2) / 2)
    ]
    target = np.reshape(target, [1, 2, 2, 2])
    self.assertAllClose(output, target)

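# A minimal TensorFlow sketch of a pixel_norm() consistent with the test above:
# each pixel is divided by the root mean square of its values across the
# channel (last) axis for NHWC inputs. The epsilon value is an assumption.
import tensorflow as tf

def pixel_norm(x, epsilon=1e-8):
    return x * tf.math.rsqrt(
        tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + epsilon)
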
def call(self, alpha, zs=None, intermediate_ws=None, mapping_network=None, cgan_w=None,
         crossover_list=None, random_crossover=False):
    """
    :param alpha:
    :param zs:
    :param intermediate_ws:
    :param mapping_network:
    :param cgan_w:
    :param crossover_list:
    :param random_crossover:
    :return:
    """
    intermediate_mode = (intermediate_ws is not None)
    mixing_mode = isinstance(zs, list) or isinstance(intermediate_ws, list)
    style_mixing = random_crossover or crossover_list is not None
    if zs is None and intermediate_ws is None:
        raise ValueError("Need z or intermediate")
    if self.use_mapping_network and (mapping_network is None and intermediate_ws is None):
        raise ValueError("No mapping network supplied to generator call")
    if not mixing_mode:
        if intermediate_mode:
            intermediate_ws = [intermediate_ws]
        else:
            zs = [zs]
    if not intermediate_mode:
        intermediate_ws = []
        for z in zs:
            z_shape = z.get_shape().as_list()
            if self.use_pixel_norm:
                z = pixel_norm(z)  # todo: verify correct
            if len(z_shape) == 2:  # [batch size, z dim]
                if self.map_cond and cgan_w is not None:
                    z = tf.concat([z, cgan_w], -1)
                intermediate_latent = mapping_network(z)
                z = tf.expand_dims(z, 1)
                z = tf.expand_dims(z, 1)
                # z = tf.reshape(z, [z_shape[0], 1, 1, -1])
            else:  # [batch size, 1, 1, z dim]
                z_flat = tf.squeeze(z, axis=[1, 2])
                if self.map_cond and cgan_w is not None:
                    z_flat = tf.concat([z_flat, cgan_w], -1)
                intermediate_latent = mapping_network(z_flat)
            intermediate_ws.append(intermediate_latent)
    if len(intermediate_ws) > 1 and not random_crossover and crossover_list is None:
        raise ValueError("Need crossover for mixing mode")
    if cgan_w is not None and not self.map_cond:
        intermediate_latent_cond = tf.concat([intermediate_latent, cgan_w], -1)
    else:
        intermediate_latent_cond = None
    intermediate_latent_cond = None
    batch_size = tf.shape(intermediate_ws[0])[0]
    latent_size = tf.shape(intermediate_ws[0])[1]
    if self.learned_input is not None:
        z = tf.expand_dims(self.learned_input(None), axis=0)
        x = tf.tile(z, [batch_size, 1, 1, 1])
    else:
        x = tf.pad(z, [[0, 0], [3, 3], [3, 3], [0, 0]])
    if self.model_res_w == 1:  # for testing purposes
        return x
    current_res = self.start_shape[1]
    # Inefficient implementation, will have to redo
    with tf.name_scope("style_mixing"):
        intermediate_for_layer_list = []
        if random_crossover:
            intermediate_for_layer_list = []
            intermediate_mixing_schedule = tf.random.uniform(
                [batch_size], 0, len(self.model_layers), dtype=tf.int32)
            intermediate_mixing_schedule = tf.transpose(
                tf.one_hot(intermediate_mixing_schedule, depth=len(self.model_layers),
                           dtype=tf.int32))
            intermediate_multiplier_for_current_layer = tf.zeros([batch_size], dtype=tf.int32)
            for i in range(0, len(self.model_layers)):
                intermediate_multiplier_for_current_layer = tf.bitwise.bitwise_or(
                    intermediate_multiplier_for_current_layer, intermediate_mixing_schedule[i])
                intermediate_multiplier = tf.cast(intermediate_multiplier_for_current_layer,
                                                  dtype=tf.float32)
                intermediate_multiplier = tf.expand_dims(intermediate_multiplier, 1)
                intermediate_for_layer_list.append(
                    (1 - intermediate_multiplier) * intermediate_ws[0]
                    + intermediate_multiplier * intermediate_ws[1])
        elif crossover_list:
            for i in range(0, len(self.model_layers)):
                intermediate_index = 0
                for c in crossover_list:
                    if i >= c:
                        intermediate_index += 1
                intermediate_for_layer_list.append(intermediate_ws[intermediate_index])
    to_rgb_lower = 0.
    layer_counter = 0
    # shape: [num_layers, batch_size, len(intermediate_w)]
    # for i in range(0, len(self.model_layers)):
    #     latents_to_swap = tf.random.categorical([batch_size, 2])
    #     ([batch_size, latent_size], minval=0, maxval=1, dtype=tf.int32, )
    #     intermediate_for_layer_list
    # if random_crossover:
    #     crossover_layer = tf.random_uniform([tf.shape(intermediate_ws[0])[0], 1], 0,
    #                                         len(self.model_layers), dtype=tf.int32)
    for conv1, noise1, bias1, tostyle1, conv2, noise2, bias2, tostyle2 in self.model_layers:
        with tf.name_scope("Res%d" % current_res):
            # apply_conditioning = intermediate_latent_cond is not None and \
            #     (self.cond_layers is None or layer_counter in self.cond_layers)
            apply_conditioning = False
            if self.include_fmap_add_ops:
                x += tf.zeros(tf.shape(x), dtype=tf.float32, name="FmapRes%d" % current_res)
            if layer_counter != 0 or self.learned_input is None:
                x = conv1(x)
            if self.add_noise:
                with tf.name_scope("noise_add1"):
                    noise_inputs = noise1(False)
                    assert (x.get_shape().as_list()[1:] == noise_inputs.get_shape().as_list()[1:])
                    x += noise_inputs
            x = bias1(x)
            x = tf.nn.leaky_relu(x, alpha=.2)
            if self.use_pixel_norm:
                x = pixel_norm(x)
            if apply_conditioning:
                ys, yb = tostyle1(intermediate_latent_cond)
            else:
                if style_mixing:
                    ys, yb = tostyle1(intermediate_for_layer_list[layer_counter])
                else:
                    ys, yb = tostyle1(intermediate_ws[0])
            x = adaptive_instance_norm(x, ys, yb)
            x = conv2(x)
            if self.use_pixel_norm:
                x = pixel_norm(x)
            if self.add_noise:
                with tf.name_scope("noise_add2"):
                    noise_inputs = noise2(False)
                    assert (x.get_shape().as_list()[1:] == noise_inputs.get_shape().as_list()[1:])
                    x += noise_inputs
            x = bias2(x)
            x = tf.nn.leaky_relu(x, alpha=.2)
            if apply_conditioning:
                ys, yb = tostyle2(intermediate_latent_cond)
            else:
                if style_mixing:
                    ys, yb = tostyle2(intermediate_for_layer_list[layer_counter])
                else:
                    ys, yb = tostyle2(intermediate_ws[0])
            x = adaptive_instance_norm(x, ys, yb)
            if current_res == self.model_res_w // 2:
                to_rgb_lower = upsample(self.toRGB_lower(x), method=self.resize_method)
            if current_res != self.model_res_w:
                x = upsample(x, method=self.resize_method)
            layer_counter += 1
            current_res *= 2
    to_rgb = self.toRGB(x)
    output = to_rgb_lower + alpha * (to_rgb - to_rgb_lower)
    if self.output_res_w // self.model_res_w >= 2:
        output = upsample(output, method='nearest_neighbor',
                          factor=self.output_res_w // self.model_res_w)
    return output

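# adaptive_instance_norm() is used by the generator above but not defined in
# this collection. A minimal NHWC sketch of StyleGAN-style AdaIN: instance-
# normalize each feature map per sample, then scale and shift it with the
# style projections ys/yb. The [batch, channels] shape of ys/yb and the
# epsilon value are assumptions.
import tensorflow as tf

def adaptive_instance_norm(x, ys, yb, epsilon=1e-8):
    # per-sample, per-channel statistics over the spatial axes (NHWC)
    mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    var = tf.reduce_mean(tf.square(x - mean), axis=[1, 2], keepdims=True)
    normalized = (x - mean) * tf.math.rsqrt(var + epsilon)
    # broadcast the style scale and bias over the spatial dimensions
    ys = ys[:, tf.newaxis, tf.newaxis, :]
    yb = yb[:, tf.newaxis, tf.newaxis, :]
    return ys * normalized + yb
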
def call(self, x):
    x = pixel_norm(x)
    for l in self.fc_layers:
        x = tf.nn.leaky_relu(l(x), alpha=.2)
        # tf.summary.histogram("mapping_network_outputs", x)
    return x

def generate(self, latent_var, model_progressive_depth=1, transition=False,
             alpha_transition=0.0, reuse=False):
    with tf.variable_scope('generator') as scope:
        if reuse:
            scope.reuse_variables()

        convs = []
        convs += [tf.reshape(latent_var, [self.batch_size, 1, 1, self.latent_size])]
        convs[-1] = pixel_norm(
            lrelu(
                conv2d(convs[-1], output_dim=self.get_filter_num(1), k_h=4, k_w=4,
                       d_w=1, d_h=1, padding='Other', name='gen_n_1_conv')))

        convs += [tf.reshape(convs[-1], [self.batch_size, 4, 4, self.get_filter_num(1)])]
        convs[-1] = pixel_norm(
            lrelu(
                conv2d(convs[-1], output_dim=self.get_filter_num(1), d_w=1, d_h=1,
                       name='gen_n_2_conv')))

        for i in range(model_progressive_depth - 1):
            if i == model_progressive_depth - 2 and transition:
                # To RGB, low resolution
                transition_conv = conv2d(convs[-1], output_dim=self.channels, k_w=1, k_h=1,
                                         d_w=1, d_h=1,
                                         name='gen_y_rgb_conv_{}'.format(convs[-1].shape[1]))
                transition_conv = upscale(transition_conv, 2)

            convs += [upscale(convs[-1], 2)]
            convs[-1] = pixel_norm(
                lrelu(
                    conv2d(convs[-1], output_dim=self.get_filter_num(i + 1), d_w=1, d_h=1,
                           name='gen_n_conv_1_{}'.format(convs[-1].shape[1]))))
            convs += [
                pixel_norm(
                    lrelu(
                        conv2d(convs[-1], output_dim=self.get_filter_num(i + 1), d_w=1, d_h=1,
                               name='gen_n_conv_2_{}'.format(convs[-1].shape[1]))))
            ]

        # To RGB, high resolution
        convs += [
            conv2d(convs[-1], output_dim=self.channels, k_w=1, k_h=1, d_w=1, d_h=1,
                   name='gen_y_rgb_conv_{}'.format(convs[-1].shape[1]))
        ]

        if transition:
            convs[-1] = (1 - alpha_transition) * transition_conv + alpha_transition * convs[-1]

        return convs[-1]

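# lrelu() and upscale() are assumed by generate() above (and by the helpers
# below) but are not defined in this collection. Minimal sketches of the usual
# leaky-ReLU and nearest-neighbour upscaling helpers, written against the
# TF1-style API used above; the leak factor is an assumption.
import tensorflow as tf

def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

def upscale(x, scale):
    # NHWC nearest-neighbour upscaling by an integer factor (static H/W assumed)
    shape = x.get_shape().as_list()
    return tf.image.resize_nearest_neighbor(x, [shape[1] * scale, shape[2] * scale])
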
def conv_bn(x, dim, ksize, name):
    y = conv(x, dim, ksize, 1, name)
    y = lrelu(y)
    y = pixel_norm(y)
    return y

def conv_upsample(x, dim, ksize, name):
    y = upscale2d_conv2d(x, dim, ksize, name)
    y = blur2d(y)
    y = lrelu(y)
    y = pixel_norm(y)
    return y

def generator(lr_image, scope, nchannels, nresblocks, dim):
    """ Generator """
    hr_images = dict()

    def conv_upsample(x, dim, ksize, name):
        y = upscale2d_conv2d(x, dim, ksize, name)
        y = blur2d(y)
        y = lrelu(y)
        y = pixel_norm(y)
        return y

    def _residule_block(x, dim, name):
        with tf.compat.v1.variable_scope(name):
            y = conv(x, dim, 3, 1, "conv1")
            y = lrelu(y)
            y = pixel_norm(y)
            y = conv(y, dim, 3, 1, "conv2")
            y = pixel_norm(y)
            return y + x

    def conv_bn(x, dim, ksize, name):
        y = conv(x, dim, ksize, 1, name)
        y = lrelu(y)
        y = pixel_norm(y)
        return y

    def _make_output(net, factor):
        hr_images[factor] = conv(net, nchannels, 1, 1, "output")

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        with tf.compat.v1.variable_scope("encoder"):
            net = lrelu(conv(lr_image, dim, 9, 1, "conv1_9x9"))
            conv1 = net
            for i in range(nresblocks):
                net = _residule_block(net, dim=dim, name="ResBlock{}".format(i))
        with tf.compat.v1.variable_scope("res_1x"):
            net = conv(net, dim, 3, 1, "conv1")
            net = pixel_norm(net)
            net += conv1
            _make_output(net, factor=4)
        with tf.compat.v1.variable_scope("res_2x"):
            net = conv_upsample(net, 4 * dim, 3, "conv_upsample")
            net = conv_bn(net, 4 * dim, 3, "conv1")
            net = conv_bn(net, 4 * dim, 3, "conv2")
            net = conv_bn(net, 4 * dim, 5, "conv3")
            _make_output(net, factor=2)
        with tf.compat.v1.variable_scope("res_4x"):
            net = conv_upsample(net, 4 * dim, 3, "conv_upsample")
            net = conv_bn(net, 4 * dim, 3, "conv1")
            net = conv_bn(net, 4 * dim, 3, "conv2")
            net = conv_bn(net, 4 * dim, 9, "conv3")
            _make_output(net, factor=1)
    return hr_images

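# A hypothetical call to the super-resolution generator above; the input
# resolution, channel count and hyper-parameter values are illustrative
# assumptions, not taken from the original code.
import tensorflow as tf

lr = tf.compat.v1.placeholder(tf.float32, [None, 32, 32, 3], name="lr_image")
hr_images = generator(lr, scope="sr_generator", nchannels=3, nresblocks=8, dim=64)
# hr_images is keyed by the `factor` passed to _make_output: 4 (base
# resolution), 2 (after one 2x upsample) and 1 (after two 2x upsamples).
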