Example #1
def resblock_g(h, y, scopename,
               n_classes, maps, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
               upsample=True, test=False, sn=True, coefs=[1.0]):
    """Residual block for generator"""

    s = h
    _, c, _, _ = h.shape
    with nn.parameter_scope(scopename):
        # BN -> Relu -> Upsample -> Conv
        with nn.parameter_scope("conv1"):
            h = CCBN(h, y, n_classes, test=test, coefs=coefs)
            h = F.relu(h, inplace=True)
            if upsample:
                h = F.unpooling(h, kernel=(2, 2))
            h = convolution(h, maps, kernel=kernel, pad=pad, stride=stride,
                            with_bias=True, sn=sn, test=test, init_scale=np.sqrt(2))

        # BN -> Relu -> Conv
        with nn.parameter_scope("conv2"):
            h = CCBN(h, y, n_classes, test=test, coefs=coefs)
            h = F.relu(h, inplace=True)
            h = convolution(h, maps, kernel=kernel, pad=pad, stride=stride,
                            with_bias=True, sn=sn, test=test, init_scale=np.sqrt(2))

        # Shortcut: Upsample -> Conv
        if upsample:
            s = F.unpooling(s, kernel=(2, 2))
        if c != maps or upsample:
            with nn.parameter_scope("shortcut"):
                s = convolution(s, maps, kernel=(1, 1), pad=(0, 0), stride=(1, 1),
                                with_bias=True, sn=sn, test=test)
    return F.add2(h, s, True)
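A hypothetical call site for the block above, only to illustrate the expected arguments; CCBN and convolution are project-specific helpers defined elsewhere, and the shapes and argument values here are made up:

import nnabla as nn

# Hypothetical NCHW feature map and label variable for the conditional
# batch normalization (CCBN) inside the block; the shapes are made up.
h = nn.Variable((8, 256, 16, 16))
y = nn.Variable((8,))
h = resblock_g(h, y, "resblock_16x16", n_classes=1000, maps=256, upsample=True)
# The 2x2 unpooling doubles the spatial size: h is now (8, 256, 32, 32).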
Example #2
    def backward_impl(self, inputs, outputs, prop_down, accum):
        # inputs: [inputs_fwd_graph] + [inputs_bwd_graph] or
        # [inputs_fwd_graph] + [outputs_fwd_graph] + [inputs_bwd_graph]

        # Args
        kernel = self.forward_func.info.args["kernel"]

        # Inputs
        x0 = inputs[0].data
        dy = inputs[1].data
        # Outputs
        dx0 = outputs[0].data
        # Grads of inputs
        g_x0 = inputs[0].grad
        g_dy = inputs[1].grad
        # Grads of outputs
        g_dx0 = outputs[0].grad

        # Compute
        if prop_down[1]:
            # Optimize by creating max_pooling with indices
            g_dy_ = F.unpooling(g_dx0, kernel)
            if accum[1]:
                g_dy += g_dy_
            else:
                g_dy.copy_from(g_dy_)
Example #3
def upsample(x, maps, norm="ln", pad_mode="reflect", name="upsample"):
    h = x
    with nn.parameter_scope(name):
        #h = F.interpolate(h, (2, 2), mode="linear")
        h = F.unpooling(h, (2, 2))
        h = convblock(h, maps, 5, 2, 1, norm=norm, pad_mode=pad_mode)
    return h
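For reference, a minimal sketch of what the F.unpooling call above does: with a (2, 2) kernel every value is repeated along both spatial axes, so the feature map doubles in height and width (the shape here is made up):

import nnabla as nn
import nnabla.functions as F

x = nn.Variable((1, 16, 8, 8))      # NCHW input; the shape is made up
y = F.unpooling(x, kernel=(2, 2))   # each value is repeated 2x along H and W
print(y.shape)                      # (1, 16, 16, 16)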
Example #4
    def cnn(self, h, resolution, channel, test):
        """CNN block

        The following operations are performed two times.

        1. Upsampling
        2. Conv
        3. Pixel-wise normalization
        4. Relu
        """
        h = F.unpooling(h, kernel=(2, 2))
        with nn.parameter_scope("phase_{}".format(resolution)):
            with nn.parameter_scope("conv1"):
                h = conv(h, channel, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                         with_bias=not self.use_bn,
                         use_wscale=self.use_wscale,
                         use_he_backward=self.use_he_backward)
                h = pixel_wise_feature_vector_normalization(
                    BN(h, use_bn=self.use_bn, test=test))
                h = self.activation(h)
            with nn.parameter_scope("conv2"):
                h = conv(h, channel, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                         with_bias=not self.use_bn,
                         use_wscale=self.use_wscale,
                         use_he_backward=self.use_he_backward)
                h = pixel_wise_feature_vector_normalization(
                    BN(h, use_bn=self.use_bn, test=test))
                h = self.activation(h)
        return h
Example #5
def unpool_block(x,
                 n=0,
                 k=(4, 4),
                 s=(2, 2),
                 p=(1, 1),
                 leaky=False,
                 unpool=False,
                 init_method=None):
    if not unpool:
        logger.info("Deconvolution was used.")
        x = deconvolution(x,
                          n=n,
                          kernel=k,
                          stride=s,
                          pad=p,
                          init_method=init_method)
    else:
        logger.info("Unpooling was used.")
        x = F.unpooling(x, kernel=(2, 2))
        x = convolution(x,
                        n,
                        kernel=(3, 3),
                        stride=(1, 1),
                        pad=(1, 1),
                        init_method=init_method)
    x = instance_normalization(x, fix_parameters=True)
    x = F.leaky_relu(x, alpha=0.2) if leaky else F.relu(x)
    return x
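A hypothetical comparison of the two branches above (deconvolution, convolution, and instance_normalization are project helpers assumed to be available, and the shapes are made up); both double the spatial resolution, and the unpooling-plus-convolution path is often chosen to avoid the checkerboard artifacts that stride-2 deconvolution can produce:

import nnabla as nn

x = nn.Variable((1, 64, 32, 32))
with nn.parameter_scope("up_deconv"):
    a = unpool_block(x, n=32, unpool=False)  # 4x4 stride-2 deconvolution
with nn.parameter_scope("up_unpool"):
    b = unpool_block(x, n=32, unpool=True)   # 2x2 unpooling + 3x3 convolution
# Both branches give (1, 32, 64, 64): each doubles the spatial resolution.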
Example #6
    def upsampling_block(x, i):

        with nn.parameter_scope('us_block-%2d' % i):
            up = F.unpooling(af(x), (2,))
            cac_x = crop_and_concat(ds_outputs[-i - 1], up)
            us = af(conv(cac_x,
                         num_initial_filters + num_initial_filters * (num_layers - i - 1),
                         (merge_filter_size,), (2,), name='conv'))
            return us
Example #7
def unpooling_data_grad_backward(inputs, kernel, channel_last=False):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kernel (tuple of int): Kernel size used by the corresponding forward function.
      channel_last (bool): If True, the last axis is treated as the channel axis (NHWC layout).

    Returns:
      nn.Variable: Gradient with respect to the input of the corresponding function.
    """
    gdx = inputs[0]
    gdy = F.unpooling(gdx, kernel, channel_last)
    return gdy
Example #8
def upsample(h, maps, up, test=False, name="convblock"):
    if up == "nearest":
        h = PF.convolution(h, maps, (3, 3), (1, 1), name=name)
        h = F.interpolate(h, scale=(2, 2), mode="nearest")
    elif up == "linear":
        h = PF.convolution(h, maps, (3, 3), (1, 1), name=name)
        h = F.interpolate(h, scale=(2, 2), mode="linear")
    elif up == "unpooling":
        h = PF.convolution(h, maps, (3, 3), (1, 1), name=name)
        h = F.unpooling(h, (2, 2))
    elif up == "deconv":
        h = PF.deconvolution(h, maps * 2, (2, 2), (0, 0), (2, 2), name=name)
    else:
        raise ValueError(
            'Set "up" option in ["nearest", "linear", "unpooling", "deconv"]')
    h = PF.batch_normalization(h, batch_stat=not test, name=name)
    h = F.relu(h)

    return h
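A minimal usage sketch for the helper above, selecting the "unpooling" branch; the shapes are made up:

import nnabla as nn

h = nn.Variable((4, 64, 16, 16))
h = upsample(h, maps=32, up="unpooling", test=True, name="up1")
# 3x3 convolution to 32 maps (same spatial size), then 2x2 unpooling,
# batch normalization and ReLU: h is now (4, 32, 32, 32).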
Example #9
    def transition_cnn(self, h, pre_resolution, nxt_resolution, pre_channel,
                       nxt_channel, alpha, test):
        lhs = self.to_RGB(F.unpooling(h, kernel=(2, 2)), pre_resolution)
        rhs = self.to_RGB(self.cnn(h, nxt_resolution, nxt_channel, test),
                          nxt_resolution)
        return (1 - alpha) * lhs + alpha * rhs
Example #10
    def _transition(self, epoch_per_resolution):
        batch_size = self.di.batch_size
        resolution = self.gen.resolution_list[-1]
        phase = "{}to{}".format(
            self.gen.resolution_list[-2], self.gen.resolution_list[-1])
        logger.info("phase : {}".format(phase))

        kernel_size = self.resolution_list[-1] // resolution
        kernel = (kernel_size, kernel_size)

        total_itr = (self.di.size // batch_size + 1) * epoch_per_resolution
        global_itr = 1.
        alpha = global_itr / total_itr

        for epoch in range(epoch_per_resolution):
            logger.info("epoch : {}".format(epoch + 1))
            itr = 0
            current_epoch = self.di.epoch
            while self.di.epoch == current_epoch:
                img, _ = self.di.next()
                x = nn.Variable.from_numpy_array(img)

                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha, test=True)
                # unlinked() returns a new Variable cut off from the graph,
                # so the result must be reassigned.
                y = y.unlinked()
                y.need_grad = False
                x_r = F.average_pooling(x, kernel=kernel)

                p_real = self.dis.transition(x_r, alpha)
                p_fake = self.dis.transition(y, alpha)

                loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                                  + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)

                if itr % self.n_critic + 1 == self.n_critic:
                    with nn.parameter_scope("discriminator"):
                        self.solver_dis.set_parameters(nn.get_parameters(),
                                                       reset=False, retain_state=True)
                        self.solver_dis.zero_grad()
                        loss_dis.backward(clear_buffer=True)
                        self.solver_dis.update()

                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha, test=False)
                p_fake = self.dis.transition(y, alpha)

                loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2))
                with nn.parameter_scope("generator"):
                    self.solver_gen.set_parameters(
                        nn.get_parameters(), reset=False, retain_state=True)
                    self.solver_gen.zero_grad()
                    loss_gen.backward(clear_buffer=True)
                    self.solver_gen.update()

                itr += 1
                global_itr += 1.
                alpha = global_itr / total_itr

            if epoch % self.save_image_interval + 1 == self.save_image_interval:
                z = nn.Variable.from_numpy_array(self.z_test)
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen.transition(z, alpha)
                img_name = "phase_{}_epoch_{}".format(phase, epoch + 1)
                self.monitor_image_tile.add(
                    img_name, F.unpooling(y, kernel=kernel))
Example #11
    def _train(self, epoch_per_resolution, each_save=False):
        batch_size = self.di.batch_size
        resolution = self.gen.resolution_list[-1]
        logger.info("phase : {}".format(resolution))

        kernel_size = self.resolution_list[-1] // resolution
        kernel = (kernel_size, kernel_size)

        img_name = "original_phase_{}".format(resolution)
        img, _ = self.di.next()
        self.monitor_image_tile.add(img_name, img)

        for epoch in range(epoch_per_resolution):
            logger.info("epoch : {}".format(epoch + 1))
            itr = 0
            current_epoch = self.di.epoch
            while self.di.epoch == current_epoch:
                img, _ = self.di.next()
                x = nn.Variable.from_numpy_array(img)
                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=True)

                # unlinked() returns a new Variable cut off from the graph,
                # so the result must be reassigned.
                y = y.unlinked()
                y.need_grad = False
                x_r = F.average_pooling(x, kernel=kernel)

                p_real = self.dis(x_r)
                p_fake = self.dis(y)
                p_real.persistent, p_fake.persistent = True, True

                loss_dis = F.mean(F.pow_scalar((p_real - 1), 2.)
                                  + F.pow_scalar(p_fake, 2.) * self.l2_fake_weight)
                loss_dis.persistent = True

                if itr % self.n_critic + 1 == self.n_critic:
                    with nn.parameter_scope("discriminator"):
                        self.solver_dis.set_parameters(nn.get_parameters(),
                                                       reset=False, retain_state=True)
                        self.solver_dis.zero_grad()
                        loss_dis.backward(clear_buffer=True)
                        self.solver_dis.update()
                z = F.randn(shape=(batch_size, self.n_latent, 1, 1))
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=False)
                p_fake = self.dis(y)
                p_fake.persistent = True

                loss_gen = F.mean(F.pow_scalar((p_fake - 1), 2.))
                loss_gen.persistent = True

                with nn.parameter_scope("generator"):
                    self.solver_gen.set_parameters(nn.get_parameters(),
                                                   reset=False, retain_state=True)
                    self.solver_gen.zero_grad()
                    loss_gen.backward(clear_buffer=True)
                    self.solver_gen.update()

                # Monitor
                self.monitor_p_real.add(
                    self.global_itr, p_real.d.copy().mean())
                self.monitor_p_fake.add(
                    self.global_itr, p_fake.d.copy().mean())
                self.monitor_loss_dis.add(self.global_itr, loss_dis.d.copy())
                self.monitor_loss_gen.add(self.global_itr, loss_gen.d.copy())

                itr += 1
                self.global_itr += 1

            if epoch % self.save_image_interval + 1 == self.save_image_interval:
                z = nn.Variable.from_numpy_array(self.z_test)
                z = pixel_wise_feature_vector_normalization(
                    z) if self.hyper_sphere else z
                y = self.gen(z, test=True)
                img_name = "phase_{}_epoch_{}".format(resolution, epoch + 1)
                self.monitor_image_tile.add(
                    img_name, F.unpooling(y, kernel=kernel))

            if each_save:
                self.gen.save_parameters(self.monitor_path, "Gen_phase_{}_epoch_{}".format(
                    self.resolution_list[-1], epoch+1))
                self.dis.save_parameters(self.monitor_path, "Dis_phase_{}_epoch_{}".format(
                    self.resolution_list[-1], epoch+1))