Example #1
    def forward(self, z, y):
        y_onehot = one_hot_embedding(y, self.num_classes)
        z_in = torch.cat([z, y_onehot], dim=1)
        output = self.fc(z_in)
        output = output.view(-1, self.z_dim, 1, 1)
        output = self.TS(output)
        output = pixel_norm(output)

        output = self.deconv1(output)
        output = self.BN_1(output)
        output = self.TS(output)
        output = pixel_norm(output)

        output = self.deconv2(output)
        output = self.BN_2(output)
        output = self.TS(output)
        output = pixel_norm(output)

        output = self.deconv3(output)
        output = self.BN_3(output)
        output = self.TS(output)
        output = pixel_norm(output)

        output = self.deconv4(output)
        output = self.outact(output)

        return output.view(-1, 32 * 32)
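
The PyTorch examples here (and Examples #4, #6, #7 below) call a pixel_norm helper that is never shown. A minimal sketch consistent with those call sites, assuming the standard PGGAN formulation over the channel dimension; the epsilon stabilizer is an assumed value, not necessarily the source's:

import torch

# Hypothetical pixel_norm for the PyTorch examples:
# divide each pixel's feature vector by its channel-wise RMS.
def pixel_norm(x, epsilon=1e-8):
    return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + epsilon)
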
Example #2
File: network.py Project: remicres/sr4rs
 def _residule_block(x, dim, name):
     with tf.compat.v1.variable_scope(name):
         y = conv(x, dim, 3, 1, "conv1")
         y = lrelu(y)
         y = pixel_norm(y)
         y = conv(y, dim, 3, 1, "conv2")
         y = pixel_norm(y)
         return y + x
Example #3
def gblock(name, inputs, filters, data_format):
    with tf.variable_scope(name):
        x = ops.conv2d_up('conv_up', inputs, filters, 3, data_format)
        x = ops.leaky_relu(x)
        x = ops.pixel_norm(x, data_format)
        x = ops.conv2d('conv', x, filters, 3, data_format)
        x = ops.leaky_relu(x)
        x = ops.pixel_norm(x, data_format)
        return x
Example #4
 def forward(self, x):
     h = pixel_norm(x)
     if self.upsample:
         h = self.upsample(h)
         x = self.upsample(x)
     h = self.conv1(h)
     h = self.activation(pixel_norm(h))
     h = self.conv2(h)
     if self.learnable_sc:
         x = self.conv_sc(x)
     return h + x
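
A hypothetical __init__ consistent with this forward pass, to show where conv1, conv2, upsample, and the learnable shortcut come from; the channel counts, LeakyReLU slope, and upsampling mode are assumptions:

import torch.nn as nn

class GBlock(nn.Module):
    # Sketch only: pairs with the forward() above.
    def __init__(self, in_ch, out_ch, upsample=True):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.upsample = nn.Upsample(scale_factor=2) if upsample else None
        # 1x1 shortcut projection only when channel counts differ
        self.learnable_sc = in_ch != out_ch
        if self.learnable_sc:
            self.conv_sc = nn.Conv2d(in_ch, out_ch, 1)
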
Example #5
def generator(x,
              last_layer_resolution,
              cfg,
              is_training=True,
              scope='Generator'):
    def rname(resolution):
        return str(resolution) + 'x' + str(resolution)

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        with tf.variable_scope("4x4"):
            fn4 = cfg.resolution_to_filt_num[4]
            x = ops.pixel_norm(x, cfg.data_format)
            x = ops.dense('dense', x, 4 * 4 * fn4, cfg.data_format)
            if cfg.data_format == 'NHWC':
                x = tf.reshape(x, [-1, 4, 4, fn4])
            else:
                x = tf.reshape(x, [-1, fn4, 4, 4])
            x = ops.leaky_relu(x)
            x = ops.pixel_norm(x, cfg.data_format)
            x = ops.conv2d('conv', x, fn4, 3, cfg.data_format)
            x = ops.leaky_relu(x)
            x = ops.pixel_norm(x, cfg.data_format)
        resolution = 8
        prev_x = None
        while resolution <= last_layer_resolution:
            filt_num = cfg.resolution_to_filt_num[resolution]
            prev_x = x
            x = gblock(rname(resolution), x, filt_num, cfg.data_format)
            resolution *= 2
        resolution = resolution // 2
        if resolution > cfg.starting_resolution:
            t = tf.get_variable(
                rname(resolution) + '_t',
                shape=[],
                collections=[tf.GraphKeys.GLOBAL_VARIABLES, "lerp"],
                dtype=tf.float32,
                initializer=tf.zeros_initializer(),
                trainable=False)
            x1 = ops.to_rgb('to_rgb_' + rname(resolution // 2), prev_x,
                            cfg.data_format)
            x1 = ops.upscale2d(x1, cfg.data_format)
            x2 = ops.to_rgb('to_rgb_' + rname(resolution), x, cfg.data_format)
            x = ops.lerp_clip(x1, x2, t)
        else:
            x = ops.to_rgb('to_rgb_' + rname(resolution), x, cfg.data_format)
        x_shape = utils.int_shape(x)
        assert (resolution == x_shape[1 if cfg.data_format == 'NHWC' else 3])
        assert (resolution == x_shape[2])
        return x
Example #6
    def forward(self, z, y):
        y_onehot = one_hot_embedding(y, self.num_classes)
        z_in = torch.cat([z, y_onehot], dim=1)
        output = self.fc(z_in)
        output = output.view(-1, 4 * self.model_dim, 4, 4)
        output = self.relu(output)
        output = pixel_norm(output)

        output = self.deconv1(output)
        output = output[:, :, :7, :7]
        output = self.relu(output)
        output = pixel_norm(output)

        output = self.deconv2(output)
        output = self.relu(output).contiguous()
        output = pixel_norm(output)

        output = self.deconv3(output)
        output = self.outact(output)
        return output.view(-1, IMG_W * IMG_H)
Example #7
 def forward(self, z, y):
     y_onehot = one_hot_embedding(y, self.num_classes)
     z_in = torch.cat([z, y_onehot], dim=1)
     output = self.fc(z_in)
     output = output.view(-1, 4 * self.model_dim, 4, 4)
     output = self.relu(output)
     output = pixel_norm(output)
     output = self.block1(output)
     output = self.block2(output)
     output = self.block3(output)
     output = self.outact(self.output(output))
     output = output[:, :, :-2, :-2]
     output = torch.reshape(output, [-1, IMG_H * IMG_W])
     return output
Example #8
 def test_pixel_norm_output(self):
     test_input = tf.constant([1., 2., 3., 4., 5., 6., 7., 8.],
                              shape=[1, 2, 2, 2])
     output = pixel_norm(test_input)
     target = [
         1. / np.sqrt((1.**2 + 2.**2) / 2), 2. / np.sqrt((1.**2 + 2.**2) / 2),
         3. / np.sqrt((3.**2 + 4.**2) / 2), 4. / np.sqrt((3.**2 + 4.**2) / 2),
         5. / np.sqrt((5.**2 + 6.**2) / 2), 6. / np.sqrt((5.**2 + 6.**2) / 2),
         7. / np.sqrt((7.**2 + 8.**2) / 2), 8. / np.sqrt((7.**2 + 8.**2) / 2)
     ]
     target = np.reshape(target, [1, 2, 2, 2])
     self.assertAllClose(output, target)
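
The test pins down the math: each element is divided by the square root of the mean square over the channel (last) axis. A minimal TensorFlow sketch that satisfies it; the epsilon term is an assumed stabilizer (the tested values are far from zero, so it does not affect the comparison):

import tensorflow as tf

# Sketch of pixel_norm matching the test: x / sqrt(mean(x^2) over channels).
# axis=-1 assumes NHWC data, as in the test's [1, 2, 2, 2] input.
def pixel_norm(x, axis=-1, epsilon=1e-8):
    return x * tf.math.rsqrt(
        tf.reduce_mean(tf.square(x), axis=axis, keepdims=True) + epsilon)
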
Example #9
    def call(self, alpha, zs=None, intermediate_ws=None, mapping_network=None, cgan_w=None,
             crossover_list=None, random_crossover=False):
        """
        :param alpha:
        :param zs:
        :param intermediate_ws:
        :param mapping_network:
        :param cgan_w:
        :param crossover_list:
        :param random_crossover:
        :return:
        """
        intermediate_mode = (intermediate_ws is not None)
        mixing_mode = isinstance(zs, list) or isinstance(intermediate_ws, list)
        style_mixing = random_crossover or crossover_list is not None
        if zs is None and intermediate_ws is None:
            raise ValueError("Need z or intermediate")
        if self.use_mapping_network and (mapping_network is None and intermediate_ws is None):
            raise ValueError("No mapping network supplied to generator call")

        if not mixing_mode:
            if intermediate_mode:
                intermediate_ws = [intermediate_ws]
            else:
                zs = [zs]
        if not intermediate_mode:
            intermediate_ws = []
            for z in zs:
                z_shape = z.get_shape().as_list()
                if self.use_pixel_norm:
                    z = pixel_norm(z)  # todo: verify correct
                if len(z_shape) == 2:  # [batch size, z dim]
                    if self.map_cond and cgan_w is not None:
                        z = tf.concat([z, cgan_w], -1)
                    intermediate_latent = mapping_network(z)
                    z = tf.expand_dims(z, 1)
                    z = tf.expand_dims(z, 1)
                    #z = tf.reshape(z, [z_shape[0], 1, 1, -1])
                else:  # [batch size, 1, 1, z dim]
                    z_flat = tf.squeeze(z, axis=[1, 2])
                    if self.map_cond and cgan_w is not None:
                        z_flat = tf.concat([z_flat, cgan_w], -1)
                    intermediate_latent = mapping_network(z_flat)
                intermediate_ws.append(intermediate_latent)

        if len(intermediate_ws) > 1 and not random_crossover and crossover_list is None:
            raise ValueError("Need crossover for mixing mode")

        if cgan_w is not None and not self.map_cond:
            intermediate_latent_cond = tf.concat([intermediate_latent, cgan_w], -1)
        else:
            intermediate_latent_cond = None
        # Conditioning path is currently disabled (see apply_conditioning below).
        intermediate_latent_cond = None

        batch_size = tf.shape(intermediate_ws[0])[0]
        latent_size = tf.shape(intermediate_ws[0])[1]
        if self.learned_input is not None:
            z = tf.expand_dims(self.learned_input(None), axis=0)
            x = tf.tile(z, [batch_size, 1, 1, 1])
        else:
            x = tf.pad(z, [[0, 0], [3, 3], [3, 3], [0, 0]])
        if self.model_res_w == 1:  # for testing purposes
            return x
        current_res = self.start_shape[1]

        # Inefficient implementation, will have to redo
        with tf.name_scope("style_mixing"):
            intermediate_for_layer_list = []
            if random_crossover:
                intermediate_for_layer_list = []
                intermediate_mixing_schedule = tf.random.uniform([batch_size], 0, len(self.model_layers), dtype=tf.int32)
                intermediate_mixing_schedule = tf.transpose(
                    tf.one_hot(intermediate_mixing_schedule, depth=len(self.model_layers), dtype=tf.int32))
                intermediate_multiplier_for_current_layer = tf.zeros([batch_size], dtype=tf.int32)
                for i in range(0, len(self.model_layers)):
                    intermediate_multiplier_for_current_layer = tf.bitwise.bitwise_or(
                        intermediate_multiplier_for_current_layer,
                        intermediate_mixing_schedule[i])
                    intermediate_multiplier = tf.cast(
                        intermediate_multiplier_for_current_layer, dtype=tf.float32)
                    intermediate_multiplier = tf.expand_dims(intermediate_multiplier, 1)
                    intermediate_for_layer_list.append(
                        (1-intermediate_multiplier)*intermediate_ws[0] +
                        intermediate_multiplier*intermediate_ws[1])
            elif crossover_list:
                for i in range(0, len(self.model_layers)):
                    intermediate_index = 0
                    for c in crossover_list:
                        if i >= c:
                            intermediate_index += 1
                    intermediate_for_layer_list.append(intermediate_ws[intermediate_index])
        to_rgb_lower = 0.
        layer_counter = 0
        # shape: [num_layers, batch_size, len(intermediate_w)]

        for conv1, noise1, bias1, tostyle1, conv2, noise2, bias2, tostyle2 in self.model_layers:
            with tf.name_scope("Res%d"%current_res):
                #apply_conditioning = intermediate_latent_cond is not None and \
                #    (self.cond_layers is None or
                #     layer_counter in self.cond_layers)
                apply_conditioning = False

                if self.include_fmap_add_ops:
                    x += tf.zeros(tf.shape(x), dtype=tf.float32,
                                  name="FmapRes%d" % current_res)
                if layer_counter != 0 or self.learned_input is None:
                    x = conv1(x)
                if self.add_noise:
                    with tf.name_scope("noise_add1"):
                        noise_inputs = noise1(False)
                        assert(x.get_shape().as_list()[1:] == noise_inputs.get_shape().as_list()[1:])
                        x += noise_inputs
                x = bias1(x)
                x = tf.nn.leaky_relu(x, alpha=.2)
                if self.use_pixel_norm:
                    x = pixel_norm(x)

                if apply_conditioning:
                    ys, yb = tostyle1(intermediate_latent_cond)
                else:
                    if style_mixing:
                        ys, yb = tostyle1(intermediate_for_layer_list[layer_counter])
                    else:
                        ys, yb = tostyle1(intermediate_ws[0])
                x = adaptive_instance_norm(x, ys, yb)

                x = conv2(x)
                if self.use_pixel_norm:
                    x = pixel_norm(x)
                if self.add_noise:
                    with tf.name_scope("noise_add2"):
                        noise_inputs = noise2(False)
                        assert(x.get_shape().as_list()[1:] == noise_inputs.get_shape().as_list()[1:])
                        x += noise_inputs
                x = bias2(x)
                x = tf.nn.leaky_relu(x, alpha=.2)

                if apply_conditioning:
                    ys, yb = tostyle2(intermediate_latent_cond)
                else:
                    if style_mixing:
                        ys, yb = tostyle2(intermediate_for_layer_list[layer_counter])
                    else:
                        ys, yb = tostyle2(intermediate_ws[0])
                x = adaptive_instance_norm(x, ys, yb)

                if current_res == self.model_res_w // 2:
                    to_rgb_lower = upsample(self.toRGB_lower(x), method=self.resize_method)
                if current_res != self.model_res_w:
                    x = upsample(x, method=self.resize_method)
                layer_counter += 1
                current_res *= 2
        to_rgb = self.toRGB(x)
        output = to_rgb_lower + alpha * (to_rgb - to_rgb_lower)
        if self.output_res_w//self.model_res_w >= 2:
            output = upsample(output, method='nearest_neighbor',
                              factor=self.output_res_w//self.model_res_w)
        return output
Example #10
 def call(self, x):
     x = pixel_norm(x)
     for l in self.fc_layers:
         x = tf.nn.leaky_relu(l(x), alpha=.2)
     #  tf.summary.histogram("mapping_network_outputs", x)
     return x
Example #11
    def generate(self,
                 latent_var,
                 model_progressive_depth=1,
                 transition=False,
                 alpha_transition=0.0,
                 reuse=False):

        with tf.variable_scope('generator') as scope:

            if reuse:
                scope.reuse_variables()

            convs = []

            convs += [
                tf.reshape(latent_var,
                           [self.batch_size, 1, 1, self.latent_size])
            ]
            convs[-1] = pixel_norm(
                lrelu(
                    conv2d(convs[-1],
                           output_dim=self.get_filter_num(1),
                           k_h=4,
                           k_w=4,
                           d_w=1,
                           d_h=1,
                           padding='Other',
                           name='gen_n_1_conv')))

            convs += [
                tf.reshape(convs[-1],
                           [self.batch_size, 4, 4,
                            self.get_filter_num(1)])
            ]
            convs[-1] = pixel_norm(
                lrelu(
                    conv2d(convs[-1],
                           output_dim=self.get_filter_num(1),
                           d_w=1,
                           d_h=1,
                           name='gen_n_2_conv')))

            for i in range(model_progressive_depth - 1):

                if i == model_progressive_depth - 2 and transition:
                    # To RGB, low resolution
                    transition_conv = conv2d(convs[-1],
                                             output_dim=self.channels,
                                             k_w=1,
                                             k_h=1,
                                             d_w=1,
                                             d_h=1,
                                             name='gen_y_rgb_conv_{}'.format(
                                                 convs[-1].shape[1]))
                    transition_conv = upscale(transition_conv, 2)

                convs += [upscale(convs[-1], 2)]
                convs[-1] = pixel_norm(
                    lrelu(
                        conv2d(convs[-1],
                               output_dim=self.get_filter_num(i + 1),
                               d_w=1,
                               d_h=1,
                               name='gen_n_conv_1_{}'.format(
                                   convs[-1].shape[1]))))

                convs += [
                    pixel_norm(
                        lrelu(
                            conv2d(convs[-1],
                                   output_dim=self.get_filter_num(i + 1),
                                   d_w=1,
                                   d_h=1,
                                   name='gen_n_conv_2_{}'.format(
                                       convs[-1].shape[1]))))
                ]

            # To RGB, high resolution
            convs += [
                conv2d(convs[-1],
                       output_dim=self.channels,
                       k_w=1,
                       k_h=1,
                       d_w=1,
                       d_h=1,
                       name='gen_y_rgb_conv_{}'.format(convs[-1].shape[1]))
            ]

            if transition:
                convs[-1] = ((1 - alpha_transition) * transition_conv +
                             alpha_transition * convs[-1])

            return convs[-1]
Example #12
File: network.py Project: remicres/sr4rs
 def conv_bn(x, dim, ksize, name):
     y = conv(x, dim, ksize, 1, name)
     y = lrelu(y)
     y = pixel_norm(y)
     return y
Example #13
File: network.py Project: remicres/sr4rs
 def conv_upsample(x, dim, ksize, name):
     y = upscale2d_conv2d(x, dim, ksize, name)
     y = blur2d(y)
     y = lrelu(y)
     y = pixel_norm(y)
     return y
Example #14
File: network.py Project: remicres/sr4rs
def generator(lr_image, scope, nchannels, nresblocks, dim):
    """
    Generator
    """
    hr_images = dict()

    def conv_upsample(x, dim, ksize, name):
        y = upscale2d_conv2d(x, dim, ksize, name)
        y = blur2d(y)
        y = lrelu(y)
        y = pixel_norm(y)
        return y

    def _residule_block(x, dim, name):
        with tf.compat.v1.variable_scope(name):
            y = conv(x, dim, 3, 1, "conv1")
            y = lrelu(y)
            y = pixel_norm(y)
            y = conv(y, dim, 3, 1, "conv2")
            y = pixel_norm(y)
            return y + x

    def conv_bn(x, dim, ksize, name):
        y = conv(x, dim, ksize, 1, name)
        y = lrelu(y)
        y = pixel_norm(y)
        return y

    def _make_output(net, factor):
        hr_images[factor] = conv(net, nchannels, 1, 1, "output")

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        with tf.compat.v1.variable_scope("encoder"):
            net = lrelu(conv(lr_image, dim, 9, 1, "conv1_9x9"))
            conv1 = net
            for i in range(nresblocks):
                net = _residule_block(net,
                                      dim=dim,
                                      name="ResBlock{}".format(i))

        with tf.compat.v1.variable_scope("res_1x"):
            net = conv(net, dim, 3, 1, "conv1")
            net = pixel_norm(net)
            net += conv1
            _make_output(net, factor=4)

        with tf.compat.v1.variable_scope("res_2x"):
            net = conv_upsample(net, 4 * dim, 3, "conv_upsample")
            net = conv_bn(net, 4 * dim, 3, "conv1")
            net = conv_bn(net, 4 * dim, 3, "conv2")
            net = conv_bn(net, 4 * dim, 5, "conv3")
            _make_output(net, factor=2)

        with tf.compat.v1.variable_scope("res_4x"):
            net = conv_upsample(net, 4 * dim, 3, "conv_upsample")
            net = conv_bn(net, 4 * dim, 3, "conv1")
            net = conv_bn(net, 4 * dim, 3, "conv2")
            net = conv_bn(net, 4 * dim, 9, "conv3")
            _make_output(net, factor=1)

        return hr_images
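
A hypothetical call showing how the returned dictionary is keyed: factor 4 is the encoder-resolution output (same size as the input), factor 2 the 2x branch, and factor 1 the full 4x super-resolved image. The argument values here are illustrative, not the project's defaults:

hr_images = generator(lr_image, scope="gen", nchannels=3, nresblocks=16, dim=64)
sr_full = hr_images[1]  # 4x upsampled output from the res_4x branch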