def forward(self, content, style, alpha=1.0):
        # output pastiche for original input
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        t = adain(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat
        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        # output pastiche for noisy input
        content_noise = self.generate_noisy_input(content, 40, 200)
        with torch.no_grad():
            content_noise_feat = self.encode(content_noise)
            t_noise = adain(content_noise_feat, style_feats[-1])
            t_noise = alpha * t_noise + (1 - alpha) * content_noise_feat
            g_t_noise = self.decoder(t_noise)
            g_t_noise_feats = self.encode_with_intermediate(g_t_noise)

        # calculate losses
        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_n = self.calc_noise_loss(g_t_feats[-1], g_t_noise_feats[-1])
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])

        # accumulate the style loss over the remaining encoder layers
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s, loss_n
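
Every example below calls an `adain` helper that is not shown in these snippets. A minimal sketch of adaptive instance normalization, assuming 4-D (N, C, H, W) feature maps and a small epsilon for numerical stability; the `calc_mean_std` name is an assumption:

def calc_mean_std(feat, eps=1e-5):
    # per-sample, per-channel mean and std over the spatial dimensions
    n, c = feat.size()[:2]
    feat_flat = feat.view(n, c, -1)
    std = (feat_flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    mean = feat_flat.mean(dim=2).view(n, c, 1, 1)
    return mean, std


def adain(content_feat, style_feat):
    # normalize the content features, then re-scale them with the style statistics
    c_mean, c_std = calc_mean_std(content_feat)
    s_mean, s_std = calc_mean_std(style_feat)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean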
Example #2
    def forward(self, content, style, alpha=1.0):
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        # the encoder is frozen, so the content features should not require gradients
        assert (content_feat.requires_grad is False)
        shape = content_feat.shape
        # compute a self-attention map `att` over the content features
        mat = self.attention_conv1(content_feat).view(shape[0], shape[1] // 2,
                                                      -1)
        another_mat = self.attention_conv2(content_feat).view(
            shape[0], shape[1], -1)
        mat = torch.matmul(mat.transpose(1, 2), mat)  # (B, HW, HW) spatial affinity
        mat = nn.functional.softmax(mat, 2)  # normalize over spatial positions
        att = torch.matmul(another_mat, mat).reshape(shape)
        # residual connection: re-weight the content features with the attention map
        content_feat = att * content_feat + content_feat

        t = adain(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_tv = self.tvloss(att)
        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        loss_cla = self.calc_classify_loss(g_t, style)
        loss_aes = self.calc_aesthetic_loss(g_t, style)
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s, loss_cla, loss_aes, loss_tv
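
Example #2 additionally regularizes the attention map with `self.tvloss`, which is not shown. A minimal total-variation sketch, under the assumption that it penalizes differences between neighbouring activations of a 4-D map (the class name and exact weighting are assumptions):

import torch
import torch.nn as nn


class TVLoss(nn.Module):
    # hypothetical total-variation regularizer for a (N, C, H, W) attention map
    def forward(self, x):
        diff_h = torch.mean(torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]))
        diff_w = torch.mean(torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]))
        return diff_h + diff_w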
Example #3
    def forward(self, content, style, alpha=1.0):
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)

        ## original
        t = adain(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat
        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)
        # the AdaIN paper uses t (the AdaIN output), not content_feat, as the content target
        loss_c = self.calc_content_loss(g_t_feats[-1], t)

        ## alternative (disabled): concatenate content and style features and decode with decoder2
        '''
        t = torch.cat((content_feat, style_feats[-1]), dim=1)
        g_t = self.decoder2(t) # decoded image
        g_t_feats = self.encode_with_intermediate(g_t)
        loss_c = self.calc_content_loss(g_t_feats[-1], content_feat.data)
        '''

        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s, g_t
Example #4
    def forward(self, content, style):
        style_feats = self.encode_with_intermediate(style)
        t = adain(self.encode(content), style_feats[-1])

        # Variable(t.data, requires_grad=True) detaches t from the encoder graph
        # (legacy pre-0.4 PyTorch API; t.detach().requires_grad_(True) is the modern equivalent)
        g_t = self.decoder(Variable(t.data, requires_grad=True))
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s
Example #5
    def forward(self, content, style, alpha=1.0):
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        t = adain(content_feat, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return loss_c, loss_s, g_t
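
The `calc_content_loss` and `calc_style_loss` helpers used throughout follow the AdaIN training objective: an MSE between generated and target features for content, and an MSE between channel-wise mean/std statistics for style. A minimal sketch, written here as free functions (the examples use them as methods) and reusing the `calc_mean_std` sketch above:

import torch.nn as nn

mse_loss = nn.MSELoss()


def calc_content_loss(input, target):
    # the target (the AdaIN output t) is treated as a fixed target for the decoder
    return mse_loss(input, target.detach())


def calc_style_loss(input, target):
    # match per-channel mean and std between generated and style features
    input_mean, input_std = calc_mean_std(input)
    target_mean, target_std = calc_mean_std(target)
    return mse_loss(input_mean, target_mean) + mse_loss(input_std, target_std)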
Example #6
    def forward(self, content, style, alpha=1.0):
        assert 0 <= alpha <= 1
        # Get image features
        style_feats = self.encode_with_intermediate(style)
        content_feats = self.encode(content)

        # Compute the AdaIN target and blend it with the content features
        t = adain(content_feats, style_feats[-1])
        t = alpha * t + (1 - alpha) * content_feats

        # Decode the stylized image and re-encode its features
        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)

        return t, g_t_feats, style_feats
Example #7
def style_transfer(vgg,
                   decoder,
                   content,
                   style,
                   alpha=1.0,
                   interpolation_weights=None):
    assert 0.0 <= alpha <= 1.0
    content_f = vgg(content)
    style_f = vgg(style)

    if interpolation_weights:
        # blend the AdaIN features of several style images with the given weights
        base_feat = adain(content_f, style_f)
        feat = torch.zeros_like(content_f[0:1])
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adain(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)

    return decoder(feat)
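
A hypothetical call site for `style_transfer` with style interpolation: `vgg`, `decoder`, `content` (a single image) and `styles` (a batch of N style images) are assumed to exist, and the content batch is expanded to match the style batch so `adain` sees one content/style pair per weight:

with torch.no_grad():
    weights = [0.4, 0.3, 0.2, 0.1]  # one weight per style image, summing to 1
    content_batch = content.expand(styles.size(0), -1, -1, -1)
    stylized = style_transfer(vgg, decoder, content_batch, styles,
                              alpha=1.0, interpolation_weights=weights)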
Example #8
    def forward(self, content, style, alpha=0.9):
        assert 0 <= alpha <= 1
        style_feats = self.encode_with_intermediate(style)
        content_feat = self.encode(content)
        if debug: print("content_feat shape: ", content_feat.shape)
        t = adain(content_feat, style_feats[-1])
        if debug: print("t: ", t.shape)
        t = alpha * t + (1 - alpha) * content_feat

        g_t = self.decoder(t)
        g_t_feats = self.encode_with_intermediate(g_t)
        if debug:
            print("g_t: ", g_t.shape, "   g_t_feat: ", g_t_feats[0].shape)

        loss_c = self.calc_content_loss(g_t_feats[-1], t)
        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
        for i in range(1, 4):
            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
        return g_t, loss_c, loss_s
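
`encode_with_intermediate` and `encode`, used by every forward pass above, split the frozen VGG encoder into sequential slices (up to relu1_1, relu2_1, relu3_1 and relu4_1 in the AdaIN setup) so the style loss can be accumulated over several layers. A minimal sketch, assuming the slices are stored as self.enc_1 ... self.enc_4:

    def encode_with_intermediate(self, input):
        # return the activation after each encoder slice: [relu1_1, relu2_1, relu3_1, relu4_1]
        results = [input]
        for i in range(4):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def encode(self, input):
        # return only the final (relu4_1) activation
        for i in range(4):
            input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
        return input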
Example #9
File: train.py  Project: czczup/URST
    loss.backward()
    optimizer.step()

    writer.add_scalar('loss_content', loss_c.item(), i + 1)
    writer.add_scalar('loss_style', loss_s.item(), i + 1)

    if (i + 1) % args.save_image_interval == 0:
        writer.add_image('train/content',
                         image_process(content_images[0]),
                         global_step=i + 1, dataformats='HWC')
        writer.add_image('train/style',
                         image_process(style_images[0]),
                         global_step=i + 1, dataformats='HWC')
        # stylize the first content/style pair in the batch for a TensorBoard preview
        style_feats = network.encode_with_intermediate(style_images[0].unsqueeze(0))
        content_feat = network.encode(content_images[0].unsqueeze(0))
        t = adain(content_feat, style_feats[-1])
        g_t = network.decoder(t)
        writer.add_image('train/stylized',
                         image_process(g_t[0]),
                         global_step=i + 1, dataformats='HWC')

    eta_seconds = ((time.time() - start_time) / (i+1)) * (args.max_iter - (i+1))
    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
    if (i + 1) % 10 == 0:
        print("[train] Iters: %d/%d || Content Loss: %.2f || Style Loss: %.2f || Estimated Time: %s]"
              %((i+1), args.max_iter, loss_c.item(), loss_s.item(), eta_string))

    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        state_dict = network.decoder.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(torch.device('cpu'))